summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/asyncio/base_events.py6
-rw-r--r--Lib/asyncio/test_utils.py5
-rw-r--r--Lib/test/test_asyncio/test_base_events.py18
-rw-r--r--Lib/test/test_asyncio/test_proactor_events.py2
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py9
-rw-r--r--Lib/test/test_asyncio/test_tasks.py12
6 files changed, 39 insertions, 13 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 48b3ee3..4b7b161 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -227,6 +227,8 @@ class BaseEventLoop(events.AbstractEventLoop):
def call_at(self, when, callback, *args):
"""Like call_later(), but uses an absolute time."""
+ if tasks.iscoroutinefunction(callback):
+ raise TypeError("coroutines cannot be used with call_at()")
timer = events.TimerHandle(when, callback, args)
heapq.heappush(self._scheduled, timer)
return timer
@@ -241,6 +243,8 @@ class BaseEventLoop(events.AbstractEventLoop):
Any positional arguments after the callback will be passed to
the callback when it is called.
"""
+ if tasks.iscoroutinefunction(callback):
+ raise TypeError("coroutines cannot be used with call_soon()")
handle = events.Handle(callback, args)
self._ready.append(handle)
return handle
@@ -252,6 +256,8 @@ class BaseEventLoop(events.AbstractEventLoop):
return handle
def run_in_executor(self, executor, callback, *args):
+ if tasks.iscoroutinefunction(callback):
+ raise TypeError("coroutines cannot be used with run_in_executor()")
if isinstance(callback, events.Handle):
assert not args
assert not isinstance(callback, events.TimerHandle)
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py
index 7c8e1dc..deab7c3 100644
--- a/Lib/asyncio/test_utils.py
+++ b/Lib/asyncio/test_utils.py
@@ -135,7 +135,7 @@ def make_test_protocol(base):
if name.startswith('__') and name.endswith('__'):
# skip magic names
continue
- dct[name] = unittest.mock.Mock(return_value=None)
+ dct[name] = MockCallback(return_value=None)
return type('TestProtocol', (base,) + base.__bases__, dct)()
@@ -274,3 +274,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self):
pass
+
+def MockCallback(**kwargs):
+ return unittest.mock.Mock(spec=['__call__'], **kwargs)
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index 5b05684..c6950ab 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -567,6 +567,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
m_socket.getaddrinfo.return_value = [
(2, 1, 6, '', ('127.0.0.1', 10100))]
+ m_socket.getaddrinfo._is_coroutine = False
m_sock = m_socket.socket.return_value = unittest.mock.Mock()
m_sock.bind.side_effect = Err
@@ -577,6 +578,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
@unittest.mock.patch('asyncio.base_events.socket')
def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
m_socket.getaddrinfo.return_value = []
+ m_socket.getaddrinfo._is_coroutine = False
coro = self.loop.create_datagram_endpoint(
MyDatagramProto, local_addr=('localhost', 0))
@@ -681,6 +683,22 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
unittest.mock.ANY,
MyProto, sock, None, None)
+ def test_call_coroutine(self):
+ @asyncio.coroutine
+ def coroutine_function():
+ pass
+
+ with self.assertRaises(TypeError):
+ self.loop.call_soon(coroutine_function)
+ with self.assertRaises(TypeError):
+ self.loop.call_soon_threadsafe(coroutine_function)
+ with self.assertRaises(TypeError):
+ self.loop.call_later(60, coroutine_function)
+ with self.assertRaises(TypeError):
+ self.loop.call_at(self.loop.time() + 60, coroutine_function)
+ with self.assertRaises(TypeError):
+ self.loop.run_in_executor(None, coroutine_function)
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
index 9964f42..6bea1a3 100644
--- a/Lib/test/test_asyncio/test_proactor_events.py
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -402,7 +402,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
NotImplementedError, BaseProactorEventLoop, self.proactor)
def test_make_socket_transport(self):
- tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
+ tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
self.assertIsInstance(tr, _ProactorSocketTransport)
def test_loop_self_reading(self):
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
index ad0b0be..855a895 100644
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -44,8 +44,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test_make_socket_transport(self):
m = unittest.mock.Mock()
self.loop.add_reader = unittest.mock.Mock()
- self.assertIsInstance(
- self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
+ transport = self.loop._make_socket_transport(m, asyncio.Protocol())
+ self.assertIsInstance(transport, _SelectorSocketTransport)
@unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self):
@@ -54,8 +54,9 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_writer = unittest.mock.Mock()
self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock()
- self.assertIsInstance(
- self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
+ waiter = asyncio.Future(loop=self.loop)
+ transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
+ self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None)
def test_make_ssl_transport_without_ssl_error(self):
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 9abdfa5..29bdaf5 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -2,8 +2,6 @@
import gc
import unittest
-import unittest.mock
-from unittest.mock import Mock
import asyncio
from asyncio import test_utils
@@ -1358,7 +1356,7 @@ class GatherTestsBase:
def _check_success(self, **kwargs):
a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
- cb = Mock()
+ cb = test_utils.MockCallback()
fut.add_done_callback(cb)
b.set_result(1)
a.set_result(2)
@@ -1380,7 +1378,7 @@ class GatherTestsBase:
def test_one_exception(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
- cb = Mock()
+ cb = test_utils.MockCallback()
fut.add_done_callback(cb)
exc = ZeroDivisionError()
a.set_result(1)
@@ -1399,7 +1397,7 @@ class GatherTestsBase:
a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
return_exceptions=True)
- cb = Mock()
+ cb = test_utils.MockCallback()
fut.add_done_callback(cb)
exc = ZeroDivisionError()
exc2 = RuntimeError()
@@ -1460,7 +1458,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
def test_one_cancellation(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(a, b, c, d, e)
- cb = Mock()
+ cb = test_utils.MockCallback()
fut.add_done_callback(cb)
a.set_result(1)
b.cancel()
@@ -1479,7 +1477,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
for i in range(6)]
fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
- cb = Mock()
+ cb = test_utils.MockCallback()
fut.add_done_callback(cb)
a.set_result(1)
zde = ZeroDivisionError()