diff options
-rw-r--r-- | Lib/asyncio/proactor_events.py | 8 | ||||
-rw-r--r-- | Lib/asyncio/selector_events.py | 8 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 18 |
3 files changed, 34 insertions, 0 deletions
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ab566b3..751155b 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -385,12 +385,18 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): self._selector = None def sock_recv(self, sock, n): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") return self._proactor.recv(sock, n) def sock_sendall(self, sock, data): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") return self._proactor.send(sock, data) def sock_connect(self, sock, address): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") try: base_events._check_resolved_address(sock, address) except ValueError as err: @@ -401,6 +407,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): return self._proactor.connect(sock, address) def sock_accept(self, sock): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") return self._proactor.accept(sock) def _socketpair(self): diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index eca48b8..6b7bdf0 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -256,6 +256,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) self._sock_recv(fut, False, sock, n) return fut @@ -292,6 +294,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) if data: self._sock_sendall(fut, False, sock, data) @@ -333,6 +337,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) try: base_events._check_resolved_address(sock, address) @@ -374,6 +380,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) self._sock_accept(fut, False, sock) return fut diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index b065749..0cff00a 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -383,6 +383,24 @@ class EventLoopTestsMixin: self.assertEqual(read, data) def _basetest_sock_client_ops(self, httpd, sock): + # in debug mode, socket operations must fail + # if the socket is not in blocking mode + self.loop.set_debug(True) + sock.setblocking(True) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_accept(sock)) + + # test in non-blocking mode sock.setblocking(False) self.loop.run_until_complete( self.loop.sock_connect(sock, httpd.address)) |