From d6c6771fc9682713ff2ebae2cd02ddbd2b48f657 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 15 Sep 2016 17:56:36 -0400 Subject: Issue #28176: Fix callbacks race in asyncio.SelectorLoop.sock_connect. --- Lib/asyncio/selector_events.py | 27 +++--- Lib/test/test_asyncio/test_selector_events.py | 124 +++++++++++++++++++++----- Misc/NEWS | 2 + 3 files changed, 113 insertions(+), 40 deletions(-) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index c18885e..2f02d76 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -400,6 +400,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): data = data[n:] self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + @coroutine def sock_connect(self, sock, address): """Connect to a remote socket at address. @@ -408,24 +409,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") - fut = self.create_future() - if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: - self._sock_connect(fut, sock, address) - else: + if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: resolved = base_events._ensure_resolved( address, family=sock.family, proto=sock.proto, loop=self) - resolved.add_done_callback( - lambda resolved: self._on_resolved(fut, sock, resolved)) - - return fut - - def _on_resolved(self, fut, sock, resolved): - try: + if not resolved.done(): + yield from resolved _, _, _, _, address = resolved.result()[0] - except Exception as exc: - fut.set_exception(exc) - else: - self._sock_connect(fut, sock, address) + + fut = self.create_future() + self._sock_connect(fut, sock, address) + return (yield from fut) def _sock_connect(self, fut, sock, address): fd = sock.fileno() @@ -436,8 +429,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): # connection runs in background. We have to wait until the socket # becomes writable to be notified when the connection succeed or # fails. - fut.add_done_callback(functools.partial(self._sock_connect_done, - fd)) + fut.add_done_callback( + functools.partial(self._sock_connect_done, fd)) self.add_writer(fd, self._sock_connect_cb, fut, sock, address) except Exception as exc: fut.set_exception(exc) diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 73bc3f3..0c26a87 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -2,6 +2,8 @@ import errno import socket +import threading +import time import unittest from unittest import mock try: @@ -337,18 +339,6 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): (10, self.loop._sock_sendall, f, True, sock, b'data'), self.loop.add_writer.call_args[0]) - def test_sock_connect(self): - sock = test_utils.mock_nonblocking_socket() - self.loop._sock_connect = mock.Mock() - - f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) - self.assertIsInstance(f, asyncio.Future) - self.loop._run_once() - future_in, sock_in, address_in = self.loop._sock_connect.call_args[0] - self.assertEqual(future_in, f) - self.assertEqual(sock_in, sock) - self.assertEqual(address_in, ('127.0.0.1', 8080)) - def test_sock_connect_timeout(self): # asyncio issue #205: sock_connect() must unregister the socket on # timeout error @@ -360,29 +350,34 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): sock.connect.side_effect = BlockingIOError # first call to sock_connect() registers the socket - fut = self.loop.sock_connect(sock, ('127.0.0.1', 80)) + fut = self.loop.create_task( + self.loop.sock_connect(sock, ('127.0.0.1', 80))) self.loop._run_once() self.assertTrue(sock.connect.called) self.assertTrue(self.loop.add_writer.called) - self.assertEqual(len(fut._callbacks), 1) # on timeout, the socket must be unregistered sock.connect.reset_mock() - fut.set_exception(asyncio.TimeoutError) - with self.assertRaises(asyncio.TimeoutError): + fut.cancel() + with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(fut) self.assertTrue(self.loop.remove_writer.called) - def test_sock_connect_resolve_using_socket_params(self): + @mock.patch('socket.getaddrinfo') + def test_sock_connect_resolve_using_socket_params(self, m_gai): addr = ('need-resolution.com', 8080) sock = test_utils.mock_nonblocking_socket() - self.loop.getaddrinfo = mock.Mock() - self.loop.sock_connect(sock, addr) - while not self.loop.getaddrinfo.called: + m_gai.side_effect = (None, None, None, None, ('127.0.0.1', 0)) + m_gai._is_coroutine = False + con = self.loop.create_task(self.loop.sock_connect(sock, addr)) + while not m_gai.called: self.loop._run_once() - self.loop.getaddrinfo.assert_called_with( - *addr, type=sock.type, family=sock.family, proto=sock.proto, - flags=0) + m_gai.assert_called_with( + addr[0], addr[1], sock.family, sock.type, sock.proto, 0) + + con.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(con) def test__sock_connect(self): f = asyncio.Future(loop=self.loop) @@ -1792,5 +1787,88 @@ class SelectorDatagramTransportTests(test_utils.TestCase): exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) +class SelectorLoopFunctionalTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @asyncio.coroutine + def recv_all(self, sock, nbytes): + buf = b'' + while len(buf) < nbytes: + buf += yield from self.loop.sock_recv(sock, nbytes - len(buf)) + return buf + + def test_sock_connect_sock_write_race(self): + TIMEOUT = 3.0 + PAYLOAD = b'DATA' * 1024 * 1024 + + class Server(threading.Thread): + def __init__(self, *args, srv_sock, **kwargs): + super().__init__(*args, **kwargs) + self.srv_sock = srv_sock + + def run(self): + with self.srv_sock: + srv_sock.listen(100) + + sock, addr = self.srv_sock.accept() + sock.settimeout(TIMEOUT) + + with sock: + sock.sendall(b'helo') + + buf = bytearray() + while len(buf) < len(PAYLOAD): + pack = sock.recv(1024 * 65) + if not pack: + break + buf.extend(pack) + + @asyncio.coroutine + def client(addr): + sock = socket.socket() + with sock: + sock.setblocking(False) + + started = time.monotonic() + while True: + if time.monotonic() - started > TIMEOUT: + self.fail('unable to connect to the socket') + return + try: + yield from self.loop.sock_connect(sock, addr) + except OSError: + yield from asyncio.sleep(0.05, loop=self.loop) + else: + break + + # Give 'Server' thread a chance to accept and send b'helo' + time.sleep(0.1) + + data = yield from self.recv_all(sock, 4) + self.assertEqual(data, b'helo') + yield from self.loop.sock_sendall(sock, PAYLOAD) + + srv_sock = socket.socket() + srv_sock.settimeout(TIMEOUT) + srv_sock.bind(('127.0.0.1', 0)) + srv_addr = srv_sock.getsockname() + + srv = Server(srv_sock=srv_sock, daemon=True) + srv.start() + + try: + self.loop.run_until_complete( + asyncio.wait_for(client(srv_addr), loop=self.loop, + timeout=TIMEOUT)) + finally: + srv.join() + + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS b/Misc/NEWS index be95974..8f5ee3b 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -275,6 +275,8 @@ Library - Issue #26909: Fix slow pipes IO in asyncio. Patch by INADA Naoki. +- Issue #28176: Fix callbacks race in asyncio.SelectorLoop.sock_connect. + IDLE ---- -- cgit v0.12