diff options
author | Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com> | 2020-05-27 20:39:03 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-27 20:39:03 (GMT) |
commit | 3a2667d91e33493ccde113ddf5092afefc3c89fa (patch) | |
tree | cb75dad8dd392114f780a5d4ed7ab650ad0b1726 /Lib | |
parent | c011d1b5be65bb6be52de4d311b21a464fe7b0dd (diff) | |
download | cpython-3a2667d91e33493ccde113ddf5092afefc3c89fa.zip cpython-3a2667d91e33493ccde113ddf5092afefc3c89fa.tar.gz cpython-3a2667d91e33493ccde113ddf5092afefc3c89fa.tar.bz2 |
bpo-30064: Fix asyncio loop.sock_* race condition issue (GH-20369)
(cherry picked from commit 210a137396979d747c2602eeef46c34fc4955448)
Co-authored-by: Fantix King <fantix.king@gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/selector_events.py | 41 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_sock_lowlevel.py | 131 |
2 files changed, 156 insertions, 16 deletions
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index a05cbb6..884a58f 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -266,6 +266,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): (handle, writer)) if reader is not None: reader.cancel() + return handle def _remove_reader(self, fd): if self.is_closed(): @@ -302,6 +303,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): (reader, handle)) if writer is not None: writer.cancel() + return handle def _remove_writer(self, fd): """Remove a writer callback.""" @@ -329,7 +331,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def add_reader(self, fd, callback, *args): """Add a reader callback.""" self._ensure_fd_no_transport(fd) - return self._add_reader(fd, callback, *args) + self._add_reader(fd, callback, *args) def remove_reader(self, fd): """Remove a reader callback.""" @@ -339,7 +341,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def add_writer(self, fd, callback, *args): """Add a writer callback..""" self._ensure_fd_no_transport(fd) - return self._add_writer(fd, callback, *args) + self._add_writer(fd, callback, *args) def remove_writer(self, fd): """Remove a writer callback.""" @@ -362,13 +364,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): pass fut = self.create_future() fd = sock.fileno() - self.add_reader(fd, self._sock_recv, fut, sock, n) + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recv, fut, sock, n) fut.add_done_callback( - functools.partial(self._sock_read_done, fd)) + functools.partial(self._sock_read_done, fd, handle=handle)) return await fut - def _sock_read_done(self, fd, fut): - self.remove_reader(fd) + def _sock_read_done(self, fd, fut, handle=None): + if handle is None or not handle.cancelled(): + self.remove_reader(fd) def _sock_recv(self, fut, sock, n): # _sock_recv() can add itself as an I/O callback if the operation can't @@ -401,9 +405,10 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): pass fut = self.create_future() fd = sock.fileno() - self.add_reader(fd, self._sock_recv_into, fut, sock, buf) + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf) fut.add_done_callback( - functools.partial(self._sock_read_done, fd)) + functools.partial(self._sock_read_done, fd, handle=handle)) return await fut def _sock_recv_into(self, fut, sock, buf): @@ -446,11 +451,12 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): fut = self.create_future() fd = sock.fileno() - fut.add_done_callback( - functools.partial(self._sock_write_done, fd)) + self._ensure_fd_no_transport(fd) # use a trick with a list in closure to store a mutable state - self.add_writer(fd, self._sock_sendall, fut, sock, - memoryview(data), [n]) + handle = self._add_writer(fd, self._sock_sendall, fut, sock, + memoryview(data), [n]) + fut.add_done_callback( + functools.partial(self._sock_write_done, fd, handle=handle)) return await fut def _sock_sendall(self, fut, sock, view, pos): @@ -502,9 +508,11 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): # connection runs in background. We have to wait until the socket # becomes writable to be notified when the connection succeed or # fails. + self._ensure_fd_no_transport(fd) + handle = self._add_writer( + fd, self._sock_connect_cb, fut, sock, address) fut.add_done_callback( - functools.partial(self._sock_write_done, fd)) - self.add_writer(fd, self._sock_connect_cb, fut, sock, address) + functools.partial(self._sock_write_done, fd, handle=handle)) except (SystemExit, KeyboardInterrupt): raise except BaseException as exc: @@ -512,8 +520,9 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): else: fut.set_result(None) - def _sock_write_done(self, fd, fut): - self.remove_writer(fd) + def _sock_write_done(self, fd, fut, handle=None): + if handle is None or not handle.cancelled(): + self.remove_writer(fd) def _sock_connect_cb(self, fut, sock, address): if fut.done(): diff --git a/Lib/test/test_asyncio/test_sock_lowlevel.py b/Lib/test/test_asyncio/test_sock_lowlevel.py index 2f2d5a4..5e6a90a 100644 --- a/Lib/test/test_asyncio/test_sock_lowlevel.py +++ b/Lib/test/test_asyncio/test_sock_lowlevel.py @@ -1,4 +1,5 @@ import socket +import time import asyncio import sys from asyncio import proactor_events @@ -122,6 +123,136 @@ class BaseSockTestsMixin: sock = socket.socket() self._basetest_sock_recv_into(httpd, sock) + async def _basetest_sock_recv_racing(self, httpd, sock): + sock.setblocking(False) + await self.loop.sock_connect(sock, httpd.address) + + task = asyncio.create_task(self.loop.sock_recv(sock, 1024)) + await asyncio.sleep(0) + task.cancel() + + asyncio.create_task( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = await self.loop.sock_recv(sock, 1024) + # consume data + await self.loop.sock_recv(sock, 1024) + + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + async def _basetest_sock_recv_into_racing(self, httpd, sock): + sock.setblocking(False) + await self.loop.sock_connect(sock, httpd.address) + + data = bytearray(1024) + with memoryview(data) as buf: + task = asyncio.create_task( + self.loop.sock_recv_into(sock, buf[:1024])) + await asyncio.sleep(0) + task.cancel() + + task = asyncio.create_task( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + nbytes = await self.loop.sock_recv_into(sock, buf[:1024]) + # consume data + await self.loop.sock_recv_into(sock, buf[nbytes:]) + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + await task + + async def _basetest_sock_send_racing(self, listener, sock): + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + + # make connection + sock.setblocking(False) + task = asyncio.create_task( + self.loop.sock_connect(sock, listener.getsockname())) + await asyncio.sleep(0) + server = listener.accept()[0] + server.setblocking(False) + + with server: + await task + + # fill the buffer + with self.assertRaises(BlockingIOError): + while True: + sock.send(b' ' * 5) + + # cancel a blocked sock_sendall + task = asyncio.create_task( + self.loop.sock_sendall(sock, b'hello')) + await asyncio.sleep(0) + task.cancel() + + # clear the buffer + async def recv_until(): + data = b'' + while not data: + data = await self.loop.sock_recv(server, 1024) + data = data.strip() + return data + task = asyncio.create_task(recv_until()) + + # immediately register another sock_sendall + await self.loop.sock_sendall(sock, b'world') + data = await task + # ProactorEventLoop could deliver hello + self.assertTrue(data.endswith(b'world')) + + async def _basetest_sock_connect_racing(self, listener, sock): + listener.bind(('127.0.0.1', 0)) + addr = listener.getsockname() + sock.setblocking(False) + + task = asyncio.create_task(self.loop.sock_connect(sock, addr)) + await asyncio.sleep(0) + task.cancel() + + listener.listen(1) + i = 0 + while True: + try: + await self.loop.sock_connect(sock, addr) + break + except ConnectionRefusedError: # on Linux we need another retry + await self.loop.sock_connect(sock, addr) + break + except OSError as e: # on Windows we need more retries + # A connect request was made on an already connected socket + if getattr(e, 'winerror', 0) == 10056: + break + + # https://stackoverflow.com/a/54437602/3316267 + if getattr(e, 'winerror', 0) != 10022: + raise + i += 1 + if i >= 128: + raise # too many retries + # avoid touching event loop to maintain race condition + time.sleep(0.01) + + def test_sock_client_racing(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + with sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_recv_racing(httpd, sock), 10)) + sock = socket.socket() + with sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_recv_into_racing(httpd, sock), 10)) + listener = socket.socket() + sock = socket.socket() + with listener, sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_send_racing(listener, sock), 10)) + listener = socket.socket() + sock = socket.socket() + with listener, sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_connect_racing(listener, sock), 10)) + async def _basetest_huge_content(self, address): sock = socket.socket() sock.setblocking(False) |