diff options
| author | Victor Stinner <victor.stinner@gmail.com> | 2014-08-31 13:08:21 (GMT) | 
|---|---|---|
| committer | Victor Stinner <victor.stinner@gmail.com> | 2014-08-31 13:08:21 (GMT) | 
| commit | de993bd9b68f1a1c2a3208e2024c94f99eb6cd05 (patch) | |
| tree | ccbe623f58be09f2e834cce3b4ab86e07d135bb9 /Lib | |
| parent | 41c13ce50ad07847ffb609fcabdb25a07b897db4 (diff) | |
| parent | d5aeccf9767c1619faa29e8ed61c93bde7bc5e3f (diff) | |
| download | cpython-de993bd9b68f1a1c2a3208e2024c94f99eb6cd05.zip cpython-de993bd9b68f1a1c2a3208e2024c94f99eb6cd05.tar.gz cpython-de993bd9b68f1a1c2a3208e2024c94f99eb6cd05.tar.bz2  | |
(Merge 3.4) asyncio, Tulip issue 205: Fix a race condition in
BaseSelectorEventLoop.sock_connect()
There is a race condition in create_connection() used with wait_for() to have a
timeout. sock_connect() registers the file descriptor of the socket to be
notified of write event (if connect() raises BlockingIOError). When
create_connection() is cancelled with a TimeoutError, sock_connect() coroutine
gets the exception, but it doesn't unregister the file descriptor for write
event. create_connection() gets the TimeoutError and closes the socket.
If you call again create_connection(), the new socket will likely gets the same
file descriptor, which is still registered in the selector. When sock_connect()
calls add_writer(), it tries to modify the entry instead of creating a new one.
This issue was originally reported in the Trollius project, but the bug comes
from Tulip in fact (Trollius is based on Tulip):
https://bitbucket.org/enovance/trollius/issue/15/after-timeouterror-on-wait_for
This change fixes the race condition. It also makes sock_connect() more
reliable (and portable) is sock.connect() raises an InterruptedError.
Diffstat (limited to 'Lib')
| -rw-r--r-- | Lib/asyncio/selector_events.py | 44 | ||||
| -rw-r--r-- | Lib/test/test_asyncio/test_selector_events.py | 74 | 
2 files changed, 83 insertions, 35 deletions
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 0434a70..33de92e 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -8,6 +8,7 @@ __all__ = ['BaseSelectorEventLoop']  import collections  import errno +import functools  import socket  try:      import ssl @@ -345,26 +346,43 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):          except ValueError as err:              fut.set_exception(err)          else: -            self._sock_connect(fut, False, sock, address) +            self._sock_connect(fut, sock, address)          return fut -    def _sock_connect(self, fut, registered, sock, address): +    def _sock_connect(self, fut, sock, address):          fd = sock.fileno() -        if registered: -            self.remove_writer(fd) +        try: +            while True: +                try: +                    sock.connect(address) +                except InterruptedError: +                    continue +                else: +                    break +        except BlockingIOError: +            fut.add_done_callback(functools.partial(self._sock_connect_done, +                                                    sock)) +            self.add_writer(fd, self._sock_connect_cb, fut, sock, address) +        except Exception as exc: +            fut.set_exception(exc) +        else: +            fut.set_result(None) + +    def _sock_connect_done(self, sock, fut): +        self.remove_writer(sock.fileno()) + +    def _sock_connect_cb(self, fut, sock, address):          if fut.cancelled():              return +          try: -            if not registered: -                # First time around. -                sock.connect(address) -            else: -                err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) -                if err != 0: -                    # Jump to the except clause below. -                    raise OSError(err, 'Connect call failed %s' % (address,)) +            err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) +            if err != 0: +                # Jump to any except clause below. +                raise OSError(err, 'Connect call failed %s' % (address,))          except (BlockingIOError, InterruptedError): -            self.add_writer(fd, self._sock_connect, fut, True, sock, address) +            # socket is still registered, the callback will be retried later +            pass          except Exception as exc:              fut.set_exception(exc)          else: diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index df6e991..528da39 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -40,8 +40,9 @@ def list_to_buffer(l=()):  class BaseSelectorEventLoopTests(test_utils.TestCase):      def setUp(self): -        selector = mock.Mock() -        self.loop = TestBaseSelectorEventLoop(selector) +        self.selector = mock.Mock() +        self.selector.select.return_value = [] +        self.loop = TestBaseSelectorEventLoop(self.selector)          self.set_event_loop(self.loop, cleanup=False)      def test_make_socket_transport(self): @@ -303,63 +304,92 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):          f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))          self.assertIsInstance(f, asyncio.Future)          self.assertEqual( -            (f, False, sock, ('127.0.0.1', 8080)), +            (f, sock, ('127.0.0.1', 8080)),              self.loop._sock_connect.call_args[0]) +    def test_sock_connect_timeout(self): +        # Tulip issue #205: sock_connect() must unregister the socket on +        # timeout error + +        # prepare mocks +        self.loop.add_writer = mock.Mock() +        self.loop.remove_writer = mock.Mock() +        sock = test_utils.mock_nonblocking_socket() +        sock.connect.side_effect = BlockingIOError + +        # first call to sock_connect() registers the socket +        fut = self.loop.sock_connect(sock, ('127.0.0.1', 80)) +        self.assertTrue(sock.connect.called) +        self.assertTrue(self.loop.add_writer.called) +        self.assertEqual(len(fut._callbacks), 1) + +        # on timeout, the socket must be unregistered +        sock.connect.reset_mock() +        fut.set_exception(asyncio.TimeoutError) +        with self.assertRaises(asyncio.TimeoutError): +            self.loop.run_until_complete(fut) +        self.assertTrue(self.loop.remove_writer.called) +      def test__sock_connect(self):          f = asyncio.Future(loop=self.loop)          sock = mock.Mock()          sock.fileno.return_value = 10 -        self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) +        self.loop._sock_connect(f, sock, ('127.0.0.1', 8080))          self.assertTrue(f.done())          self.assertIsNone(f.result())          self.assertTrue(sock.connect.called) -    def test__sock_connect_canceled_fut(self): +    def test__sock_connect_cb_cancelled_fut(self):          sock = mock.Mock() +        self.loop.remove_writer = mock.Mock()          f = asyncio.Future(loop=self.loop)          f.cancel() -        self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) -        self.assertFalse(sock.connect.called) +        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) +        self.assertFalse(sock.getsockopt.called) + +    def test__sock_connect_writer(self): +        # check that the fd is registered and then unregistered +        self.loop._process_events = mock.Mock() +        self.loop.add_writer = mock.Mock() +        self.loop.remove_writer = mock.Mock() -    def test__sock_connect_unregister(self):          sock = mock.Mock()          sock.fileno.return_value = 10 +        sock.connect.side_effect = BlockingIOError +        sock.getsockopt.return_value = 0 +        address = ('127.0.0.1', 8080)          f = asyncio.Future(loop=self.loop) -        f.cancel() +        self.loop._sock_connect(f, sock, address) +        self.assertTrue(self.loop.add_writer.called) +        self.assertEqual(10, self.loop.add_writer.call_args[0][0]) -        self.loop.remove_writer = mock.Mock() -        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) +        self.loop._sock_connect_cb(f, sock, address) +        # need to run the event loop to execute _sock_connect_done() callback +        self.loop.run_until_complete(f)          self.assertEqual((10,), self.loop.remove_writer.call_args[0]) -    def test__sock_connect_tryagain(self): +    def test__sock_connect_cb_tryagain(self):          f = asyncio.Future(loop=self.loop)          sock = mock.Mock()          sock.fileno.return_value = 10          sock.getsockopt.return_value = errno.EAGAIN -        self.loop.add_writer = mock.Mock() -        self.loop.remove_writer = mock.Mock() - -        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) -        self.assertEqual( -            (10, self.loop._sock_connect, f, -             True, sock, ('127.0.0.1', 8080)), -            self.loop.add_writer.call_args[0]) +        # check that the exception is handled +        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) -    def test__sock_connect_exception(self): +    def test__sock_connect_cb_exception(self):          f = asyncio.Future(loop=self.loop)          sock = mock.Mock()          sock.fileno.return_value = 10          sock.getsockopt.return_value = errno.ENOTCONN          self.loop.remove_writer = mock.Mock() -        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) +        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))          self.assertIsInstance(f.exception(), OSError)      def test_sock_accept(self):  | 
