diff options
author | Yury Selivanov <yury@magic.io> | 2016-11-09 20:47:00 (GMT) |
---|---|---|
committer | Yury Selivanov <yury@magic.io> | 2016-11-09 20:47:00 (GMT) |
commit | a1a8b7d3d7f628aec31be364c77cbb3e21cdbc0b (patch) | |
tree | f1c1fa5ae61d965d131a437904bd1712c7a6f017 | |
parent | d2fd3599abeed393ccdf4ee5cf1c7b346ba4a022 (diff) | |
download | cpython-a1a8b7d3d7f628aec31be364c77cbb3e21cdbc0b.zip cpython-a1a8b7d3d7f628aec31be364c77cbb3e21cdbc0b.tar.gz cpython-a1a8b7d3d7f628aec31be364c77cbb3e21cdbc0b.tar.bz2 |
Issue #28652: Make loop methods reject socket kinds they do not support.
-rw-r--r-- | Lib/asyncio/base_events.py | 56 | ||||
-rw-r--r-- | Lib/asyncio/unix_events.py | 4 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_base_events.py | 63 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 11 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_unix_events.py | 27 | ||||
-rw-r--r-- | Misc/NEWS | 2 |
6 files changed, 139 insertions, 24 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 6488f23..aa78367 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -84,12 +84,26 @@ def _set_reuseport(sock): 'SO_REUSEPORT defined but not implemented.') -# Linux's sock.type is a bitmask that can include extra info about socket. -_SOCKET_TYPE_MASK = 0 -if hasattr(socket, 'SOCK_NONBLOCK'): - _SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK -if hasattr(socket, 'SOCK_CLOEXEC'): - _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC +def _is_stream_socket(sock): + # Linux's socket.type is a bitmask that can include extra info + # about socket, therefore we can't do simple + # `sock_type == socket.SOCK_STREAM`. + return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM + + +def _is_dgram_socket(sock): + # Linux's socket.type is a bitmask that can include extra info + # about socket, therefore we can't do simple + # `sock_type == socket.SOCK_DGRAM`. + return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM + + +def _is_ip_socket(sock): + if sock.family == socket.AF_INET: + return True + if hasattr(socket, 'AF_INET6') and sock.family == socket.AF_INET6: + return True + return False def _ipaddr_info(host, port, family, type, proto): @@ -102,8 +116,12 @@ def _ipaddr_info(host, port, family, type, proto): host is None: return None - type &= ~_SOCKET_TYPE_MASK if type == socket.SOCK_STREAM: + # Linux only: + # getaddrinfo() can raise when socket.type is a bit mask. + # So if socket.type is a bit mask of SOCK_STREAM, and say + # SOCK_NONBLOCK, we simply return None, which will trigger + # a call to getaddrinfo() letting it process this request. proto = socket.IPPROTO_TCP elif type == socket.SOCK_DGRAM: proto = socket.IPPROTO_UDP @@ -124,7 +142,9 @@ def _ipaddr_info(host, port, family, type, proto): return None if family == socket.AF_UNSPEC: - afs = [socket.AF_INET, socket.AF_INET6] + afs = [socket.AF_INET] + if hasattr(socket, 'AF_INET6'): + afs.append(socket.AF_INET6) else: afs = [family] @@ -771,9 +791,13 @@ class BaseEventLoop(events.AbstractEventLoop): raise OSError('Multiple exceptions: {}'.format( ', '.join(str(exc) for exc in exceptions))) - elif sock is None: - raise ValueError( - 'host and port was not specified and no sock specified') + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + if not _is_stream_socket(sock) or not _is_ip_socket(sock): + raise ValueError( + 'A TCP Stream Socket was expected, got {!r}'.format(sock)) transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) @@ -817,6 +841,9 @@ class BaseEventLoop(events.AbstractEventLoop): allow_broadcast=None, sock=None): """Create datagram connection.""" if sock is not None: + if not _is_dgram_socket(sock): + raise ValueError( + 'A UDP Socket was expected, got {!r}'.format(sock)) if (local_addr or remote_addr or family or proto or flags or reuse_address or reuse_port or allow_broadcast): @@ -1027,6 +1054,9 @@ class BaseEventLoop(events.AbstractEventLoop): else: if sock is None: raise ValueError('Neither host/port nor sock were specified') + if not _is_stream_socket(sock) or not _is_ip_socket(sock): + raise ValueError( + 'A TCP Stream Socket was expected, got {!r}'.format(sock)) sockets = [sock] server = Server(self, sockets) @@ -1048,6 +1078,10 @@ class BaseEventLoop(events.AbstractEventLoop): This method is a coroutine. When completed, the coroutine returns a (transport, protocol) pair. """ + if not _is_stream_socket(sock): + raise ValueError( + 'A Stream Socket was expected, got {!r}'.format(sock)) + transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True) if self._debug: diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 65b61db..788a5a0 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -235,7 +235,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): if sock is None: raise ValueError('no path and sock were specified') if (sock.family != socket.AF_UNIX or - sock.type != socket.SOCK_STREAM): + not base_events._is_stream_socket(sock)): raise ValueError( 'A UNIX Domain Stream Socket was expected, got {!r}' .format(sock)) @@ -289,7 +289,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): 'path was not specified, and no sock specified') if (sock.family != socket.AF_UNIX or - sock.type != socket.SOCK_STREAM): + not base_events._is_stream_socket(sock)): raise ValueError( 'A UNIX Domain Stream Socket was expected, got {!r}' .format(sock)) diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index cdbd587..2a93923 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -116,6 +116,13 @@ class BaseEventTests(test_utils.TestCase): self.assertIsNone( base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP)) + if hasattr(socket, 'SOCK_NONBLOCK'): + self.assertEqual( + None, + base_events._ipaddr_info( + '1.2.3.4', 1, INET, STREAM | socket.SOCK_NONBLOCK, TCP)) + + def test_port_parameter_types(self): # Test obscure kinds of arguments for "port". INET = socket.AF_INET @@ -1040,6 +1047,43 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): MyProto, 'example.com', 80, sock=object()) self.assertRaises(ValueError, self.loop.run_until_complete, coro) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets') + def test_create_connection_wrong_sock(self): + sock = socket.socket(socket.AF_UNIX) + with sock: + coro = self.loop.create_connection(MyProto, sock=sock) + with self.assertRaisesRegex(ValueError, + 'A TCP Stream Socket was expected'): + self.loop.run_until_complete(coro) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets') + def test_create_server_wrong_sock(self): + sock = socket.socket(socket.AF_UNIX) + with sock: + coro = self.loop.create_server(MyProto, sock=sock) + with self.assertRaisesRegex(ValueError, + 'A TCP Stream Socket was expected'): + self.loop.run_until_complete(coro) + + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), + 'no socket.SOCK_NONBLOCK (linux only)') + def test_create_server_stream_bittype(self): + sock = socket.socket( + socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) + with sock: + coro = self.loop.create_server(lambda: None, sock=sock) + srv = self.loop.run_until_complete(coro) + srv.close() + self.loop.run_until_complete(srv.wait_closed()) + + def test_create_datagram_endpoint_wrong_sock(self): + sock = socket.socket(socket.AF_INET) + with sock: + coro = self.loop.create_datagram_endpoint(MyProto, sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UDP Socket was expected'): + self.loop.run_until_complete(coro) + def test_create_connection_no_host_port_sock(self): coro = self.loop.create_connection(MyProto) self.assertRaises(ValueError, self.loop.run_until_complete, coro) @@ -1487,36 +1531,39 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): self.assertEqual('CLOSED', protocol.state) def test_create_datagram_endpoint_sock_sockopts(self): + class FakeSock: + type = socket.SOCK_DGRAM + fut = self.loop.create_datagram_endpoint( - MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object()) + MyDatagramProto, local_addr=('127.0.0.1', 0), sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) fut = self.loop.create_datagram_endpoint( - MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object()) + MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) fut = self.loop.create_datagram_endpoint( - MyDatagramProto, family=1, sock=object()) + MyDatagramProto, family=1, sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) fut = self.loop.create_datagram_endpoint( - MyDatagramProto, proto=1, sock=object()) + MyDatagramProto, proto=1, sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) fut = self.loop.create_datagram_endpoint( - MyDatagramProto, flags=1, sock=object()) + MyDatagramProto, flags=1, sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) fut = self.loop.create_datagram_endpoint( - MyDatagramProto, reuse_address=True, sock=object()) + MyDatagramProto, reuse_address=True, sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) fut = self.loop.create_datagram_endpoint( - MyDatagramProto, reuse_port=True, sock=object()) + MyDatagramProto, reuse_port=True, sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) fut = self.loop.create_datagram_endpoint( - MyDatagramProto, allow_broadcast=True, sock=object()) + MyDatagramProto, allow_broadcast=True, sock=FakeSock()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) def test_create_datagram_endpoint_sockopts(self): diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 5b32332..28d92a9 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -791,9 +791,9 @@ class EventLoopTestsMixin: conn, _ = lsock.accept() proto = MyProto(loop=loop) proto.loop = loop - f = loop.create_task( + loop.run_until_complete( loop.connect_accepted_socket( - (lambda : proto), conn, ssl=server_ssl)) + (lambda: proto), conn, ssl=server_ssl)) loop.run_forever() proto.transport.close() lsock.close() @@ -1377,6 +1377,11 @@ class EventLoopTestsMixin: server.transport.close() def test_create_datagram_endpoint_sock(self): + if (sys.platform == 'win32' and + isinstance(self.loop, proactor_events.BaseProactorEventLoop)): + raise unittest.SkipTest( + 'UDP is not supported with proactor event loops') + sock = None local_address = ('127.0.0.1', 0) infos = self.loop.run_until_complete( @@ -1394,7 +1399,7 @@ class EventLoopTestsMixin: else: assert False, 'Can not create socket.' - f = self.loop.create_connection( + f = self.loop.create_datagram_endpoint( lambda: MyDatagramProto(loop=self.loop), sock=sock) tr, pr = self.loop.run_until_complete(f) self.assertIsInstance(tr, asyncio.Transport) diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py index 83a035e..89c6eed 100644 --- a/Lib/test/test_asyncio/test_unix_events.py +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -280,6 +280,33 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase): 'A UNIX Domain Stream.*was expected'): self.loop.run_until_complete(coro) + def test_create_unix_server_path_dgram(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + with sock: + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Stream.*was expected'): + self.loop.run_until_complete(coro) + + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), + 'no socket.SOCK_NONBLOCK (linux only)') + def test_create_unix_server_path_stream_bittype(self): + sock = socket.socket( + socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) + with tempfile.NamedTemporaryFile() as file: + fn = file.name + try: + with sock: + sock.bind(fn) + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + srv = self.loop.run_until_complete(coro) + srv.close() + self.loop.run_until_complete(srv.wait_closed()) + finally: + os.unlink(fn) + def test_create_unix_connection_path_inetsock(self): sock = socket.socket() with sock: @@ -455,6 +455,8 @@ Library - Issue #28639: Fix inspect.isawaitable to always return bool Patch by Justin Mayfield. +- Issue #28652: Make loop methods reject socket kinds they do not support. + IDLE ---- |