diff options
| author | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
|---|---|---|
| committer | Yury Selivanov <yselivanov@sprymix.com> | 2014-02-18 17:15:06 (GMT) |
| commit | 88a5bf0b2e2a55d8418132001a611af9c0419665 (patch) | |
| tree | 03841a088e2f8c04c8182999944711f4db039053 /Lib/test/test_asyncio/test_events.py | |
| parent | c36e504c53bb20ee6880b78d77aa1378519c3743 (diff) | |
| download | cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.zip cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.gz cpython-88a5bf0b2e2a55d8418132001a611af9c0419665.tar.bz2 | |
asyncio: Add support for UNIX Domain Sockets.
Diffstat (limited to 'Lib/test/test_asyncio/test_events.py')
| -rw-r--r-- | Lib/test/test_asyncio/test_events.py | 349 |
1 files changed, 235 insertions, 114 deletions
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 8c32a6e..c9d04c0 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -39,13 +39,14 @@ def data_file(filename): return fullname raise FileNotFoundError(filename) + ONLYCERT = data_file('ssl_cert.pem') ONLYKEY = data_file('ssl_key.pem') SIGNED_CERTFILE = data_file('keycert3.pem') SIGNING_CA = data_file('pycacert.pem') -class MyProto(asyncio.Protocol): +class MyBaseProto(asyncio.Protocol): done = None def __init__(self, loop=None): @@ -59,7 +60,6 @@ class MyProto(asyncio.Protocol): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' - transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') def data_received(self, data): assert self.state == 'CONNECTED', self.state @@ -76,6 +76,12 @@ class MyProto(asyncio.Protocol): self.done.set_result(None) +class MyProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + class MyDatagramProto(asyncio.DatagramProtocol): done = None @@ -357,22 +363,30 @@ class EventLoopTestsMixin: r.close() self.assertGreaterEqual(len(data), 200) + def _basetest_sock_client_ops(self, httpd, sock): + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + def test_sock_client_ops(self): with test_utils.run_test_server() as httpd: sock = socket.socket() - sock.setblocking(False) - self.loop.run_until_complete( - self.loop.sock_connect(sock, httpd.address)) - self.loop.run_until_complete( - self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - data = self.loop.run_until_complete( - self.loop.sock_recv(sock, 1024)) - # consume data - self.loop.run_until_complete( - self.loop.sock_recv(sock, 1024)) - sock.close() + self._basetest_sock_client_ops(httpd, sock) - self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_unix_sock_client_ops(self): + with test_utils.run_test_unix_server() as httpd: + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_client_ops(httpd, sock) def test_sock_client_fail(self): # Make sure that we will get an unused port @@ -485,16 +499,26 @@ class EventLoopTestsMixin: self.loop.run_forever() self.assertEqual(caught, 1) + def _basetest_create_connection(self, connection_fut): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + def test_create_connection(self): with test_utils.run_test_server() as httpd: - f = self.loop.create_connection( + conn_fut = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address) - tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, asyncio.Transport) - self.assertIsInstance(pr, asyncio.Protocol) - self.loop.run_until_complete(pr.done) - self.assertGreater(pr.nbytes, 0) - tr.close() + self._basetest_create_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), httpd.address) + self._basetest_create_connection(conn_fut) def test_create_connection_sock(self): with test_utils.run_test_server() as httpd: @@ -524,20 +548,37 @@ class EventLoopTestsMixin: self.assertGreater(pr.nbytes, 0) tr.close() + def _basetest_create_ssl_connection(self, connection_fut): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: - f = self.loop.create_connection( - lambda: MyProto(loop=self.loop), *httpd.address, + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, ssl=test_utils.dummy_ssl_context()) - tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, asyncio.Transport) - self.assertIsInstance(pr, asyncio.Protocol) - self.assertTrue('ssl' in tr.__class__.__name__.lower()) - self.assertIsNotNone(tr.get_extra_info('sockname')) - self.loop.run_until_complete(pr.done) - self.assertGreater(pr.nbytes, 0) - tr.close() + + self._basetest_create_ssl_connection(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_ssl_unix_connection(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='127.0.0.1') + + self._basetest_create_ssl_connection(conn_fut) def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: @@ -561,14 +602,8 @@ class EventLoopTestsMixin: self.assertIn(str(httpd.address), cm.exception.strerror) def test_create_server(self): - proto = None - - def factory(): - nonlocal proto - proto = MyProto() - return proto - - f = self.loop.create_server(factory, '0.0.0.0', 0) + proto = MyProto() + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) server = self.loop.run_until_complete(f) self.assertEqual(len(server.sockets), 1) sock = server.sockets[0] @@ -605,38 +640,76 @@ class EventLoopTestsMixin: # close server server.close() - def _make_ssl_server(self, factory, certfile, keyfile=None): + def _make_unix_server(self, factory, **kwargs): + path = test_utils.gen_unix_socket_path() + self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) + + f = self.loop.create_unix_server(factory, path, **kwargs) + server = self.loop.run_until_complete(f) + + return server, path + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server(self): + proto = MyProto() + server, path = self._make_unix_server(lambda: proto) + self.assertEqual(len(server.sockets), 1) + + client = socket.socket(socket.AF_UNIX) + client.connect(path) + client.sendall(b'xxx') + test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto is not None, 10) + + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def _create_ssl_context(self, certfile, keyfile=None): sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext - f = self.loop.create_server( - factory, '127.0.0.1', 0, ssl=sslcontext) + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext) server = self.loop.run_until_complete(f) + sock = server.sockets[0] host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') return server, host, port + def _make_ssl_unix_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + return self._make_unix_server(factory, ssl=sslcontext) + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl(self): - proto = None - - class ClientMyProto(MyProto): - def connection_made(self, transport): - self.transport = transport - assert self.state == 'INITIAL', self.state - self.state = 'CONNECTED' + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, ONLYCERT, ONLYKEY) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto - - server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY) - - f_c = self.loop.create_connection(ClientMyProto, host, port, + f_c = self.loop.create_connection(MyBaseProto, host, port, ssl=test_utils.dummy_ssl_context()) client, pr = self.loop.run_until_complete(f_c) @@ -667,16 +740,45 @@ class EventLoopTestsMixin: server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_verify_failed(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, ONLYCERT, ONLYKEY) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + f_c = self.loop.create_unix_connection( + MyBaseProto, path, ssl=test_utils.dummy_ssl_context(), + server_hostname='') + + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # stop serving + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -697,15 +799,36 @@ class EventLoopTestsMixin: @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_match_failed(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # no CA loaded + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='invalid') + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_match_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -729,15 +852,36 @@ class EventLoopTestsMixin: @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_verified(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -915,19 +1059,15 @@ class EventLoopTestsMixin: @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_read_pipe(self): - proto = None - - def factory(): - nonlocal proto - proto = MyReadPipeProto(loop=self.loop) - return proto + proto = MyReadPipeProto(loop=self.loop) rpipe, wpipe = os.pipe() pipeobj = io.open(rpipe, 'rb', 1024) @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + t, p = yield from self.loop.connect_read_pipe( + lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) @@ -959,19 +1099,14 @@ class EventLoopTestsMixin: # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 @support.requires_freebsd_version(8) def test_read_pty_output(self): - proto = None - - def factory(): - nonlocal proto - proto = MyReadPipeProto(loop=self.loop) - return proto + proto = MyReadPipeProto(loop=self.loop) master, slave = os.openpty() master_read_obj = io.open(master, 'rb', 0) @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe(factory, + t, p = yield from self.loop.connect_read_pipe(lambda: proto, master_read_obj) self.assertIs(p, proto) self.assertIs(t, proto.transport) @@ -999,21 +1134,17 @@ class EventLoopTestsMixin: @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + t, p = yield from self.loop.connect_write_pipe( + lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual('CONNECTED', proto.state) @@ -1045,21 +1176,16 @@ class EventLoopTestsMixin: @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe_disconnect_on_close(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - rsock, wsock = test_utils.socketpair() pipeobj = io.open(wsock.detach(), 'wb', 1024) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, + t, p = yield from self.loop.connect_write_pipe(lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) @@ -1084,21 +1210,16 @@ class EventLoopTestsMixin: # older than 10.6 (Snow Leopard) @support.requires_mac_ver(10, 6) def test_write_pty(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - master, slave = os.openpty() slave_write_obj = io.open(slave, 'wb', 0) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, + t, p = yield from self.loop.connect_write_pipe(lambda: proto, slave_write_obj) self.assertIs(p, proto) self.assertIs(t, proto.transport) |
