summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVictor Stinner <victor.stinner@gmail.com>2014-07-29 21:08:17 (GMT)
committerVictor Stinner <victor.stinner@gmail.com>2014-07-29 21:08:17 (GMT)
commit9c9f1f10d391ff3458a80fc3d0110870d50012e2 (patch)
tree91366a85b6b7bb730cce303233aa0914aff8dbd9
parentf2ed889027f23ecf184a6ef52b7a7cc4921c3000 (diff)
downloadcpython-9c9f1f10d391ff3458a80fc3d0110870d50012e2.zip
cpython-9c9f1f10d391ff3458a80fc3d0110870d50012e2.tar.gz
cpython-9c9f1f10d391ff3458a80fc3d0110870d50012e2.tar.bz2
Close #22063: socket operations (socket,recv, sock_sendall, sock_connect,
sock_accept) now raise an exception in debug mode if sockets are in blocking mode.
-rw-r--r--Lib/asyncio/proactor_events.py8
-rw-r--r--Lib/asyncio/selector_events.py8
-rw-r--r--Lib/test/test_asyncio/test_events.py18
3 files changed, 34 insertions, 0 deletions
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index ab566b3..751155b 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -385,12 +385,18 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._selector = None
def sock_recv(self, sock, n):
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
return self._proactor.recv(sock, n)
def sock_sendall(self, sock, data):
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
return self._proactor.send(sock, data)
def sock_connect(self, sock, address):
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
try:
base_events._check_resolved_address(sock, address)
except ValueError as err:
@@ -401,6 +407,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
return self._proactor.connect(sock, address)
def sock_accept(self, sock):
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
return self._proactor.accept(sock)
def _socketpair(self):
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index eca48b8..6b7bdf0 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -256,6 +256,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
This method is a coroutine.
"""
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self)
self._sock_recv(fut, False, sock, n)
return fut
@@ -292,6 +294,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
This method is a coroutine.
"""
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self)
if data:
self._sock_sendall(fut, False, sock, data)
@@ -333,6 +337,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
This method is a coroutine.
"""
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self)
try:
base_events._check_resolved_address(sock, address)
@@ -374,6 +380,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
This method is a coroutine.
"""
+ if self.get_debug() and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self)
self._sock_accept(fut, False, sock)
return fut
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index b065749..0cff00a 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -383,6 +383,24 @@ class EventLoopTestsMixin:
self.assertEqual(read, data)
def _basetest_sock_client_ops(self, httpd, sock):
+ # in debug mode, socket operations must fail
+ # if the socket is not in blocking mode
+ self.loop.set_debug(True)
+ sock.setblocking(True)
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, httpd.address))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ with self.assertRaises(ValueError):
+ self.loop.run_until_complete(
+ self.loop.sock_accept(sock))
+
+ # test in non-blocking mode
sock.setblocking(False)
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))