diff options
author | Victor Stinner <victor.stinner@gmail.com> | 2014-08-31 13:08:21 (GMT) |
---|---|---|
committer | Victor Stinner <victor.stinner@gmail.com> | 2014-08-31 13:08:21 (GMT) |
commit | de993bd9b68f1a1c2a3208e2024c94f99eb6cd05 (patch) | |
tree | ccbe623f58be09f2e834cce3b4ab86e07d135bb9 /Lib | |
parent | 41c13ce50ad07847ffb609fcabdb25a07b897db4 (diff) | |
parent | d5aeccf9767c1619faa29e8ed61c93bde7bc5e3f (diff) | |
download | cpython-de993bd9b68f1a1c2a3208e2024c94f99eb6cd05.zip cpython-de993bd9b68f1a1c2a3208e2024c94f99eb6cd05.tar.gz cpython-de993bd9b68f1a1c2a3208e2024c94f99eb6cd05.tar.bz2 |
(Merge 3.4) asyncio, Tulip issue 205: Fix a race condition in
BaseSelectorEventLoop.sock_connect()
There is a race condition in create_connection() used with wait_for() to have a
timeout. sock_connect() registers the file descriptor of the socket to be
notified of write event (if connect() raises BlockingIOError). When
create_connection() is cancelled with a TimeoutError, sock_connect() coroutine
gets the exception, but it doesn't unregister the file descriptor for write
event. create_connection() gets the TimeoutError and closes the socket.
If you call again create_connection(), the new socket will likely gets the same
file descriptor, which is still registered in the selector. When sock_connect()
calls add_writer(), it tries to modify the entry instead of creating a new one.
This issue was originally reported in the Trollius project, but the bug comes
from Tulip in fact (Trollius is based on Tulip):
https://bitbucket.org/enovance/trollius/issue/15/after-timeouterror-on-wait_for
This change fixes the race condition. It also makes sock_connect() more
reliable (and portable) is sock.connect() raises an InterruptedError.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/selector_events.py | 44 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_selector_events.py | 74 |
2 files changed, 83 insertions, 35 deletions
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 0434a70..33de92e 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -8,6 +8,7 @@ __all__ = ['BaseSelectorEventLoop'] import collections import errno +import functools import socket try: import ssl @@ -345,26 +346,43 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): except ValueError as err: fut.set_exception(err) else: - self._sock_connect(fut, False, sock, address) + self._sock_connect(fut, sock, address) return fut - def _sock_connect(self, fut, registered, sock, address): + def _sock_connect(self, fut, sock, address): fd = sock.fileno() - if registered: - self.remove_writer(fd) + try: + while True: + try: + sock.connect(address) + except InterruptedError: + continue + else: + break + except BlockingIOError: + fut.add_done_callback(functools.partial(self._sock_connect_done, + sock)) + self.add_writer(fd, self._sock_connect_cb, fut, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def _sock_connect_done(self, sock, fut): + self.remove_writer(sock.fileno()) + + def _sock_connect_cb(self, fut, sock, address): if fut.cancelled(): return + try: - if not registered: - # First time around. - sock.connect(address) - else: - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - # Jump to the except clause below. - raise OSError(err, 'Connect call failed %s' % (address,)) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to any except clause below. + raise OSError(err, 'Connect call failed %s' % (address,)) except (BlockingIOError, InterruptedError): - self.add_writer(fd, self._sock_connect, fut, True, sock, address) + # socket is still registered, the callback will be retried later + pass except Exception as exc: fut.set_exception(exc) else: diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index df6e991..528da39 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -40,8 +40,9 @@ def list_to_buffer(l=()): class BaseSelectorEventLoopTests(test_utils.TestCase): def setUp(self): - selector = mock.Mock() - self.loop = TestBaseSelectorEventLoop(selector) + self.selector = mock.Mock() + self.selector.select.return_value = [] + self.loop = TestBaseSelectorEventLoop(self.selector) self.set_event_loop(self.loop, cleanup=False) def test_make_socket_transport(self): @@ -303,63 +304,92 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) self.assertIsInstance(f, asyncio.Future) self.assertEqual( - (f, False, sock, ('127.0.0.1', 8080)), + (f, sock, ('127.0.0.1', 8080)), self.loop._sock_connect.call_args[0]) + def test_sock_connect_timeout(self): + # Tulip issue #205: sock_connect() must unregister the socket on + # timeout error + + # prepare mocks + self.loop.add_writer = mock.Mock() + self.loop.remove_writer = mock.Mock() + sock = test_utils.mock_nonblocking_socket() + sock.connect.side_effect = BlockingIOError + + # first call to sock_connect() registers the socket + fut = self.loop.sock_connect(sock, ('127.0.0.1', 80)) + self.assertTrue(sock.connect.called) + self.assertTrue(self.loop.add_writer.called) + self.assertEqual(len(fut._callbacks), 1) + + # on timeout, the socket must be unregistered + sock.connect.reset_mock() + fut.set_exception(asyncio.TimeoutError) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(fut) + self.assertTrue(self.loop.remove_writer.called) + def test__sock_connect(self): f = asyncio.Future(loop=self.loop) sock = mock.Mock() sock.fileno.return_value = 10 - self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect(f, sock, ('127.0.0.1', 8080)) self.assertTrue(f.done()) self.assertIsNone(f.result()) self.assertTrue(sock.connect.called) - def test__sock_connect_canceled_fut(self): + def test__sock_connect_cb_cancelled_fut(self): sock = mock.Mock() + self.loop.remove_writer = mock.Mock() f = asyncio.Future(loop=self.loop) f.cancel() - self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) - self.assertFalse(sock.connect.called) + self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.getsockopt.called) + + def test__sock_connect_writer(self): + # check that the fd is registered and then unregistered + self.loop._process_events = mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_writer = mock.Mock() - def test__sock_connect_unregister(self): sock = mock.Mock() sock.fileno.return_value = 10 + sock.connect.side_effect = BlockingIOError + sock.getsockopt.return_value = 0 + address = ('127.0.0.1', 8080) f = asyncio.Future(loop=self.loop) - f.cancel() + self.loop._sock_connect(f, sock, address) + self.assertTrue(self.loop.add_writer.called) + self.assertEqual(10, self.loop.add_writer.call_args[0][0]) - self.loop.remove_writer = mock.Mock() - self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect_cb(f, sock, address) + # need to run the event loop to execute _sock_connect_done() callback + self.loop.run_until_complete(f) self.assertEqual((10,), self.loop.remove_writer.call_args[0]) - def test__sock_connect_tryagain(self): + def test__sock_connect_cb_tryagain(self): f = asyncio.Future(loop=self.loop) sock = mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.EAGAIN - self.loop.add_writer = mock.Mock() - self.loop.remove_writer = mock.Mock() - - self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) - self.assertEqual( - (10, self.loop._sock_connect, f, - True, sock, ('127.0.0.1', 8080)), - self.loop.add_writer.call_args[0]) + # check that the exception is handled + self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) - def test__sock_connect_exception(self): + def test__sock_connect_cb_exception(self): f = asyncio.Future(loop=self.loop) sock = mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.ENOTCONN self.loop.remove_writer = mock.Mock() - self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) self.assertIsInstance(f.exception(), OSError) def test_sock_accept(self): |