From b057c52b0175a31ee8a786a780bd0eef785820ae Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 18 Feb 2014 12:15:06 -0500 Subject: asyncio: Add support for UNIX Domain Sockets. --- Lib/asyncio/base_events.py | 7 + Lib/asyncio/events.py | 26 ++ Lib/asyncio/streams.py | 39 ++- Lib/asyncio/test_utils.py | 153 ++++++++--- Lib/asyncio/unix_events.py | 75 +++++- Lib/test/test_asyncio/test_base_events.py | 2 +- Lib/test/test_asyncio/test_events.py | 349 +++++++++++++++++--------- Lib/test/test_asyncio/test_selector_events.py | 3 +- Lib/test/test_asyncio/test_streams.py | 195 +++++++++++--- Lib/test/test_asyncio/test_unix_events.py | 82 +++++- 10 files changed, 738 insertions(+), 193 deletions(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 3bbf6b5..b74e936 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop): sock.setblocking(False) + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @tasks.coroutine + def _create_connection_transport(self, sock, protocol_factory, ssl, + server_hostname): protocol = protocol_factory() waiter = futures.Future(loop=self) if ssl: diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index dd9e3fb..7841ad9 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -220,6 +220,32 @@ class AbstractEventLoop: """ raise NotImplementedError + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError + + def create_unix_server(self, protocol_factory, path, *, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. + + The return valud is a Server object, which can be used to stop + the service. + + path is a str, representing a file systsem path to bind the + server socket to. + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError + def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0): diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 8fc2147..698c5c6 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,9 +1,13 @@ """Stream-related things.""" __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server', 'IncompleteReadError', + 'open_connection', 'start_server', + 'open_unix_connection', 'start_unix_server', + 'IncompleteReadError', ] +import socket + from . import events from . import futures from . import protocols @@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *, return (yield from loop.create_server(factory, host, port, **kwds)) +if hasattr(socket, 'AF_UNIX'): + # UNIX Domain Sockets are supported on this platform + + @tasks.coroutine + def open_unix_connection(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `open_connection` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + + @tasks.coroutine + def start_unix_server(client_connected_cb, path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `start_server` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_unix_server(factory, path, **kwds)) + + class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py index deab7c3..de2916b 100644 --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -4,12 +4,18 @@ import collections import contextlib import io import os +import socket +import socketserver import sys +import tempfile import threading import time import unittest import unittest.mock + +from http.server import HTTPServer from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + try: import ssl except ImportError: # pragma: no cover @@ -70,42 +76,51 @@ def run_once(loop): loop.run_forever() -@contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): +class SilentWSGIRequestHandler(WSGIRequestHandler): - class SilentWSGIRequestHandler(WSGIRequestHandler): - def get_stderr(self): - return io.StringIO() + def get_stderr(self): + return io.StringIO() - def log_message(self, format, *args): - pass + def log_message(self, format, *args): + pass - class SilentWSGIServer(WSGIServer): - def handle_error(self, request, client_address): + +class SilentWSGIServer(WSGIServer): + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer pass - class SSLWSGIServer(SilentWSGIServer): - def finish_request(self, request, client_address): - # The relative location of our test directory (which - # contains the ssl key and certificate files) differs - # between the stdlib and stand-alone asyncio. - # Prefer our own if we can find it. - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - if not os.path.isdir(here): - here = os.path.join(os.path.dirname(os.__file__), - 'test', 'test_asyncio') - keyfile = os.path.join(here, 'ssl_key.pem') - certfile = os.path.join(here, 'ssl_cert.pem') - ssock = ssl.wrap_socket(request, - keyfile=keyfile, - certfile=certfile, - server_side=True) - try: - self.RequestHandlerClass(ssock, client_address, self) - ssock.close() - except OSError: - # maybe socket has been closed by peer - pass + +class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): + pass + + +def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): def app(environ, start_response): status = '200 OK' @@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): # Run the test WSGI server in a separate thread in order not to # interfere with event handling in the main thread - server_class = SSLWSGIServer if use_ssl else SilentWSGIServer - httpd = make_server(host, port, app, - server_class, SilentWSGIRequestHandler) + server_class = server_ssl_cls if use_ssl else server_cls + httpd = server_class(address, SilentWSGIRequestHandler) + httpd.set_app(app) httpd.address = httpd.server_address server_thread = threading.Thread(target=httpd.serve_forever) server_thread.start() @@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): server_thread.join() +if hasattr(socket, 'AF_UNIX'): + + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + + def server_bind(self): + socketserver.UnixStreamServer.server_bind(self) + self.server_name = '127.0.0.1' + self.server_port = 80 + + + class UnixWSGIServer(UnixHTTPServer, WSGIServer): + + def server_bind(self): + UnixHTTPServer.server_bind(self) + self.setup_environ() + + def get_request(self): + request, client_addr = super().get_request() + # Code in the stdlib expects that get_request + # will return a socket and a tuple (host, port). + # However, this isn't true for UNIX sockets, + # as the second return value will be a path; + # hence we return some fake data sufficient + # to get the tests going + return request, ('127.0.0.1', '') + + + class SilentUnixWSGIServer(UnixWSGIServer): + + def handle_error(self, request, client_address): + pass + + + class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): + pass + + + def gen_unix_socket_path(): + with tempfile.NamedTemporaryFile() as file: + return file.name + + + @contextlib.contextmanager + def unix_socket_path(): + path = gen_unix_socket_path() + try: + yield path + finally: + try: + os.unlink(path) + except OSError: + pass + + + @contextlib.contextmanager + def run_test_unix_server(*, use_ssl=False): + with unix_socket_path() as path: + yield from _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer) + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + yield from _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer) + + def make_test_protocol(base): dct = {} for name in dir(base): @@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop): def _write_to_self(self): pass + def MockCallback(**kwargs): return unittest.mock.Mock(spec=['__call__'], **kwargs) diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index ea79d33..e0d7507 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -11,6 +11,7 @@ import sys import threading +from . import base_events from . import base_subprocess from . import constants from . import events @@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): - """Unix event loop + """Unix event loop. - Adds signal handling to SelectorEventLoop + Adds signal handling and UNIX Domain Socket support to SelectorEventLoop. """ def __init__(self, selector=None): @@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) + @tasks.coroutine + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + assert server_hostname is None or isinstance(server_hostname, str) + if ssl: + if server_hostname is None: + raise ValueError( + 'you have to pass server_hostname when using ssl') + else: + if server_hostname is not None: + raise ValueError('server_hostname is only meaningful with ssl') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + try: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) + sock.setblocking(False) + yield from self.sock_connect(sock, path) + except OSError: + if sock is not None: + sock.close() + raise + + else: + if sock is None: + raise ValueError('no path and sock were specified') + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @tasks.coroutine + def create_unix_server(self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None): + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + + if path is not None: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + try: + sock.bind(path) + except OSError as exc: + if exc.errno == errno.EADDRINUSE: + # Let's improve the error message by adding + # with what exact address it occurs. + msg = 'Address {!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) from None + else: + raise + else: + if sock is None: + raise ValueError( + 'path was not specified, and no sock specified') + + if sock.family != socket.AF_UNIX: + raise ValueError( + 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + + server = base_events.Server(self, [sock]) + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + def _set_nonblocking(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFL) diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 94e2d59..9fa9841 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase): idx = -1 data = [10.0, 10.0, 10.3, 13.0] - self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())] + self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())] self.loop._run_once() self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 8c32a6e..c9d04c0 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -39,13 +39,14 @@ def data_file(filename): return fullname raise FileNotFoundError(filename) + ONLYCERT = data_file('ssl_cert.pem') ONLYKEY = data_file('ssl_key.pem') SIGNED_CERTFILE = data_file('keycert3.pem') SIGNING_CA = data_file('pycacert.pem') -class MyProto(asyncio.Protocol): +class MyBaseProto(asyncio.Protocol): done = None def __init__(self, loop=None): @@ -59,7 +60,6 @@ class MyProto(asyncio.Protocol): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' - transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') def data_received(self, data): assert self.state == 'CONNECTED', self.state @@ -76,6 +76,12 @@ class MyProto(asyncio.Protocol): self.done.set_result(None) +class MyProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + class MyDatagramProto(asyncio.DatagramProtocol): done = None @@ -357,22 +363,30 @@ class EventLoopTestsMixin: r.close() self.assertGreaterEqual(len(data), 200) + def _basetest_sock_client_ops(self, httpd, sock): + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + def test_sock_client_ops(self): with test_utils.run_test_server() as httpd: sock = socket.socket() - sock.setblocking(False) - self.loop.run_until_complete( - self.loop.sock_connect(sock, httpd.address)) - self.loop.run_until_complete( - self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - data = self.loop.run_until_complete( - self.loop.sock_recv(sock, 1024)) - # consume data - self.loop.run_until_complete( - self.loop.sock_recv(sock, 1024)) - sock.close() + self._basetest_sock_client_ops(httpd, sock) - self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_unix_sock_client_ops(self): + with test_utils.run_test_unix_server() as httpd: + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_client_ops(httpd, sock) def test_sock_client_fail(self): # Make sure that we will get an unused port @@ -485,16 +499,26 @@ class EventLoopTestsMixin: self.loop.run_forever() self.assertEqual(caught, 1) + def _basetest_create_connection(self, connection_fut): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + def test_create_connection(self): with test_utils.run_test_server() as httpd: - f = self.loop.create_connection( + conn_fut = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address) - tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, asyncio.Transport) - self.assertIsInstance(pr, asyncio.Protocol) - self.loop.run_until_complete(pr.done) - self.assertGreater(pr.nbytes, 0) - tr.close() + self._basetest_create_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), httpd.address) + self._basetest_create_connection(conn_fut) def test_create_connection_sock(self): with test_utils.run_test_server() as httpd: @@ -524,20 +548,37 @@ class EventLoopTestsMixin: self.assertGreater(pr.nbytes, 0) tr.close() + def _basetest_create_ssl_connection(self, connection_fut): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: - f = self.loop.create_connection( - lambda: MyProto(loop=self.loop), *httpd.address, + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, ssl=test_utils.dummy_ssl_context()) - tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, asyncio.Transport) - self.assertIsInstance(pr, asyncio.Protocol) - self.assertTrue('ssl' in tr.__class__.__name__.lower()) - self.assertIsNotNone(tr.get_extra_info('sockname')) - self.loop.run_until_complete(pr.done) - self.assertGreater(pr.nbytes, 0) - tr.close() + + self._basetest_create_ssl_connection(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_ssl_unix_connection(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='127.0.0.1') + + self._basetest_create_ssl_connection(conn_fut) def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: @@ -561,14 +602,8 @@ class EventLoopTestsMixin: self.assertIn(str(httpd.address), cm.exception.strerror) def test_create_server(self): - proto = None - - def factory(): - nonlocal proto - proto = MyProto() - return proto - - f = self.loop.create_server(factory, '0.0.0.0', 0) + proto = MyProto() + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) server = self.loop.run_until_complete(f) self.assertEqual(len(server.sockets), 1) sock = server.sockets[0] @@ -605,38 +640,76 @@ class EventLoopTestsMixin: # close server server.close() - def _make_ssl_server(self, factory, certfile, keyfile=None): + def _make_unix_server(self, factory, **kwargs): + path = test_utils.gen_unix_socket_path() + self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) + + f = self.loop.create_unix_server(factory, path, **kwargs) + server = self.loop.run_until_complete(f) + + return server, path + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server(self): + proto = MyProto() + server, path = self._make_unix_server(lambda: proto) + self.assertEqual(len(server.sockets), 1) + + client = socket.socket(socket.AF_UNIX) + client.connect(path) + client.sendall(b'xxx') + test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto is not None, 10) + + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def _create_ssl_context(self, certfile, keyfile=None): sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext - f = self.loop.create_server( - factory, '127.0.0.1', 0, ssl=sslcontext) + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext) server = self.loop.run_until_complete(f) + sock = server.sockets[0] host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') return server, host, port + def _make_ssl_unix_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + return self._make_unix_server(factory, ssl=sslcontext) + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl(self): - proto = None - - class ClientMyProto(MyProto): - def connection_made(self, transport): - self.transport = transport - assert self.state == 'INITIAL', self.state - self.state = 'CONNECTED' + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, ONLYCERT, ONLYKEY) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto - - server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY) - - f_c = self.loop.create_connection(ClientMyProto, host, port, + f_c = self.loop.create_connection(MyBaseProto, host, port, ssl=test_utils.dummy_ssl_context()) client, pr = self.loop.run_until_complete(f_c) @@ -667,16 +740,45 @@ class EventLoopTestsMixin: server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_verify_failed(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, ONLYCERT, ONLYKEY) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + f_c = self.loop.create_unix_connection( + MyBaseProto, path, ssl=test_utils.dummy_ssl_context(), + server_hostname='') + + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # stop serving + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -697,15 +799,36 @@ class EventLoopTestsMixin: @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_match_failed(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # no CA loaded + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='invalid') + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_match_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -729,15 +852,36 @@ class EventLoopTestsMixin: @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_verified(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -915,19 +1059,15 @@ class EventLoopTestsMixin: @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_read_pipe(self): - proto = None - - def factory(): - nonlocal proto - proto = MyReadPipeProto(loop=self.loop) - return proto + proto = MyReadPipeProto(loop=self.loop) rpipe, wpipe = os.pipe() pipeobj = io.open(rpipe, 'rb', 1024) @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + t, p = yield from self.loop.connect_read_pipe( + lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) @@ -959,19 +1099,14 @@ class EventLoopTestsMixin: # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 @support.requires_freebsd_version(8) def test_read_pty_output(self): - proto = None - - def factory(): - nonlocal proto - proto = MyReadPipeProto(loop=self.loop) - return proto + proto = MyReadPipeProto(loop=self.loop) master, slave = os.openpty() master_read_obj = io.open(master, 'rb', 0) @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe(factory, + t, p = yield from self.loop.connect_read_pipe(lambda: proto, master_read_obj) self.assertIs(p, proto) self.assertIs(t, proto.transport) @@ -999,21 +1134,17 @@ class EventLoopTestsMixin: @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + t, p = yield from self.loop.connect_write_pipe( + lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual('CONNECTED', proto.state) @@ -1045,21 +1176,16 @@ class EventLoopTestsMixin: @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe_disconnect_on_close(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - rsock, wsock = test_utils.socketpair() pipeobj = io.open(wsock.detach(), 'wb', 1024) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, + t, p = yield from self.loop.connect_write_pipe(lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) @@ -1084,21 +1210,16 @@ class EventLoopTestsMixin: # older than 10.6 (Snow Leopard) @support.requires_mac_ver(10, 6) def test_write_pty(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - master, slave = os.openpty() slave_write_obj = io.open(slave, 'wb', 0) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, + t, p = yield from self.loop.connect_write_pipe(lambda: proto, slave_write_obj) self.assertIs(p, proto) self.assertIs(t, proto.transport) diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 855a895..7741e19 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase): self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock() waiter = asyncio.Future(loop=self.loop) - transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter) + transport = self.loop._make_ssl_transport( + m, asyncio.Protocol(), m, waiter) self.assertIsInstance(transport, _SelectorSslTransport) @unittest.mock.patch('asyncio.selector_events.ssl', None) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index ee3fb45..31e26a6 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1,6 +1,8 @@ """Tests for streams.py.""" +import functools import gc +import socket import unittest import unittest.mock try: @@ -32,48 +34,85 @@ class StreamReaderTests(unittest.TestCase): stream = asyncio.StreamReader() self.assertIs(stream._loop, m_events.get_event_loop.return_value) + def _basetest_open_connection(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + writer.close() + def test_open_connection(self): with test_utils.run_test_server() as httpd: - f = asyncio.open_connection(*httpd.address, loop=self.loop) - reader, writer = self.loop.run_until_complete(f) - writer.write(b'GET / HTTP/1.0\r\n\r\n') - f = reader.readline() - data = self.loop.run_until_complete(f) - self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - f = reader.read() - data = self.loop.run_until_complete(f) - self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - - writer.close() + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): + try: + reader, writer = self.loop.run_until_complete(open_connection_fut) + finally: + asyncio.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: - try: - asyncio.set_event_loop(self.loop) - f = asyncio.open_connection(*httpd.address, - ssl=test_utils.dummy_ssl_context()) - reader, writer = self.loop.run_until_complete(f) - finally: - asyncio.set_event_loop(None) - writer.write(b'GET / HTTP/1.0\r\n\r\n') - f = reader.read() - data = self.loop.run_until_complete(f) - self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + conn_fut = asyncio.open_connection( + *httpd.address, + ssl=test_utils.dummy_ssl_context(), + loop=self.loop) - writer.close() + self._basetest_open_connection_no_loop_ssl(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_no_loop_ssl(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = asyncio.open_unix_connection( + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='', + loop=self.loop) + + self._basetest_open_connection_no_loop_ssl(conn_fut) + + def _basetest_open_connection_error(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + writer.close() + test_utils.run_briefly(self.loop) def test_open_connection_error(self): with test_utils.run_test_server() as httpd: - f = asyncio.open_connection(*httpd.address, loop=self.loop) - reader, writer = self.loop.run_until_complete(f) - writer._protocol.connection_lost(ZeroDivisionError()) - f = reader.read() - with self.assertRaises(ZeroDivisionError): - self.loop.run_until_complete(f) + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) - writer.close() - test_utils.run_briefly(self.loop) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_error(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) def test_feed_empty_data(self): stream = asyncio.StreamReader(loop=self.loop) @@ -415,10 +454,13 @@ class StreamReaderTests(unittest.TestCase): client_writer.write(data) def start(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, - '127.0.0.1', 12345, + sock=sock, loop=self.loop)) + return sock.getsockname() def handle_client_callback(self, client_reader, client_writer): task = asyncio.Task(client_reader.readline(), loop=self.loop) @@ -429,10 +471,15 @@ class StreamReaderTests(unittest.TestCase): task.add_done_callback(done) def start_callback(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) + addr = sock.getsockname() + sock.close() self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client_callback, - '127.0.0.1', 12345, + host=addr[0], port=addr[1], loop=self.loop)) + return addr def stop(self): if self.server is not None: @@ -441,9 +488,9 @@ class StreamReaderTests(unittest.TestCase): self.server = None @asyncio.coroutine - def client(): + def client(addr): reader, writer = yield from asyncio.open_connection( - '127.0.0.1', 12345, loop=self.loop) + *addr, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -453,20 +500,90 @@ class StreamReaderTests(unittest.TestCase): # test the server variant with a coroutine as client handler server = MyServer(self.loop) - server.start() - msg = self.loop.run_until_complete(asyncio.Task(client(), + addr = server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") # test the server variant with a callback as client handler server = MyServer(self.loop) - server.start_callback() - msg = self.loop.run_until_complete(asyncio.Task(client(), + addr = server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_start_unix_server(self): + + class MyServer: + + def __init__(self, loop, path): + self.server = None + self.loop = loop + self.path = path + + @asyncio.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client, + path=self.path, + loop=self.loop)) + + def handle_client_callback(self, client_reader, client_writer): + task = asyncio.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client_callback, + path=self.path, + loop=self.loop)) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @asyncio.coroutine + def client(path): + reader, writer = yield from asyncio.open_unix_connection( + path, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py index 9461ec8..2fa1db4 100644 --- a/Lib/test/test_asyncio/test_unix_events.py +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -7,8 +7,10 @@ import io import os import pprint import signal +import socket import stat import sys +import tempfile import threading import unittest import unittest.mock @@ -24,7 +26,7 @@ from asyncio import unix_events @unittest.skipUnless(signal, 'Signals are not supported') -class SelectorEventLoopTests(unittest.TestCase): +class SelectorEventLoopSignalTests(unittest.TestCase): def setUp(self): self.loop = asyncio.SelectorEventLoop() @@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase): m_signal.set_wakeup_fd.assert_called_once_with(-1) +@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') +class SelectorEventLoopUnixSocketTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_create_unix_server_existing_path_sock(self): + with test_utils.unix_socket_path() as path: + sock = socket.socket(socket.AF_UNIX) + sock.bind(path) + + coro = self.loop.create_unix_server(lambda: None, path) + with self.assertRaisesRegexp(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_existing_path_nonsock(self): + with tempfile.NamedTemporaryFile() as file: + coro = self.loop.create_unix_server(lambda: None, file.name) + with self.assertRaisesRegexp(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_ssl_bool(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl=True) + with self.assertRaisesRegex(TypeError, + 'ssl argument must be an SSLContext'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_nopath_nosock(self): + coro = self.loop.create_unix_server(lambda: None, path=None) + with self.assertRaisesRegex(ValueError, + 'path was not specified, and no sock'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_path_inetsock(self): + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=socket.socket()) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Socket was expected'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_path_sock(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', sock=object()) + with self.assertRaisesRegex(ValueError, 'path and sock can not be'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nopath_nosock(self): + coro = self.loop.create_unix_connection( + lambda: None, None) + with self.assertRaisesRegex(ValueError, + 'no path and sock were specified'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nossl_serverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', server_hostname='spam') + with self.assertRaisesRegex(ValueError, + 'server_hostname is only meaningful'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_ssl_noserverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', ssl=True) + + with self.assertRaisesRegexp( + ValueError, 'you have to pass server_hostname when using ssl'): + + self.loop.run_until_complete(coro) + + class UnixReadPipeTransportTests(unittest.TestCase): def setUp(self): -- cgit v0.12