diff options
author | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
---|---|---|
committer | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
commit | 88a5bf0b2e2a55d8418132001a611af9c0419665 (patch) | |
tree | 03841a088e2f8c04c8182999944711f4db039053 /Lib/test/test_asyncio/test_streams.py | |
parent | c36e504c53bb20ee6880b78d77aa1378519c3743 (diff) | |
download | cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.zip cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.gz cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.bz2 |
asyncio: Add support for UNIX Domain Sockets.
Diffstat (limited to 'Lib/test/test_asyncio/test_streams.py')
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 195 |
1 files changed, 156 insertions, 39 deletions
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index ee3fb45..31e26a6 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1,6 +1,8 @@ """Tests for streams.py.""" +import functools import gc +import socket import unittest import unittest.mock try: @@ -32,48 +34,85 @@ class StreamReaderTests(unittest.TestCase): stream = asyncio.StreamReader() self.assertIs(stream._loop, m_events.get_event_loop.return_value) + def _basetest_open_connection(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + writer.close() + def test_open_connection(self): with test_utils.run_test_server() as httpd: - f = asyncio.open_connection(*httpd.address, loop=self.loop) - reader, writer = self.loop.run_until_complete(f) - writer.write(b'GET / HTTP/1.0\r\n\r\n') - f = reader.readline() - data = self.loop.run_until_complete(f) - self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - f = reader.read() - data = self.loop.run_until_complete(f) - self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - - writer.close() + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): + try: + reader, writer = self.loop.run_until_complete(open_connection_fut) + finally: + asyncio.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: - try: - asyncio.set_event_loop(self.loop) - f = asyncio.open_connection(*httpd.address, - ssl=test_utils.dummy_ssl_context()) - reader, writer = self.loop.run_until_complete(f) - finally: - asyncio.set_event_loop(None) - writer.write(b'GET / HTTP/1.0\r\n\r\n') - f = reader.read() - data = self.loop.run_until_complete(f) - self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + conn_fut = asyncio.open_connection( + *httpd.address, + ssl=test_utils.dummy_ssl_context(), + loop=self.loop) - writer.close() + self._basetest_open_connection_no_loop_ssl(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_no_loop_ssl(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = asyncio.open_unix_connection( + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='', + loop=self.loop) + + self._basetest_open_connection_no_loop_ssl(conn_fut) + + def _basetest_open_connection_error(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + writer.close() + test_utils.run_briefly(self.loop) def test_open_connection_error(self): with test_utils.run_test_server() as httpd: - f = asyncio.open_connection(*httpd.address, loop=self.loop) - reader, writer = self.loop.run_until_complete(f) - writer._protocol.connection_lost(ZeroDivisionError()) - f = reader.read() - with self.assertRaises(ZeroDivisionError): - self.loop.run_until_complete(f) + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) - writer.close() - test_utils.run_briefly(self.loop) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_error(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) def test_feed_empty_data(self): stream = asyncio.StreamReader(loop=self.loop) @@ -415,10 +454,13 @@ class StreamReaderTests(unittest.TestCase): client_writer.write(data) def start(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, - '127.0.0.1', 12345, + sock=sock, loop=self.loop)) + return sock.getsockname() def handle_client_callback(self, client_reader, client_writer): task = asyncio.Task(client_reader.readline(), loop=self.loop) @@ -429,10 +471,15 @@ class StreamReaderTests(unittest.TestCase): task.add_done_callback(done) def start_callback(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) + addr = sock.getsockname() + sock.close() self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client_callback, - '127.0.0.1', 12345, + host=addr[0], port=addr[1], loop=self.loop)) + return addr def stop(self): if self.server is not None: @@ -441,9 +488,9 @@ class StreamReaderTests(unittest.TestCase): self.server = None @asyncio.coroutine - def client(): + def client(addr): reader, writer = yield from asyncio.open_connection( - '127.0.0.1', 12345, loop=self.loop) + *addr, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -453,20 +500,90 @@ class StreamReaderTests(unittest.TestCase): # test the server variant with a coroutine as client handler server = MyServer(self.loop) - server.start() - msg = self.loop.run_until_complete(asyncio.Task(client(), + addr = server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") # test the server variant with a callback as client handler server = MyServer(self.loop) - server.start_callback() - msg = self.loop.run_until_complete(asyncio.Task(client(), + addr = server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_start_unix_server(self): + + class MyServer: + + def __init__(self, loop, path): + self.server = None + self.loop = loop + self.path = path + + @asyncio.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client, + path=self.path, + loop=self.loop)) + + def handle_client_callback(self, client_reader, client_writer): + task = asyncio.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client_callback, + path=self.path, + loop=self.loop)) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @asyncio.coroutine + def client(path): + reader, writer = yield from asyncio.open_unix_connection( + path, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + if __name__ == '__main__': unittest.main() |