diff options
author | Guido van Rossum <guido@dropbox.com> | 2013-10-17 20:40:50 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@dropbox.com> | 2013-10-17 20:40:50 (GMT) |
commit | 27b7c7ebf1039e96cac41b6330cf16b5632d9e49 (patch) | |
tree | 814505b0f9d02a5cabdec733dcde70250b04ee28 /Lib/test/test_asyncio | |
parent | 5b37f97ea5ac9f6b33b0e0269c69539cbb478142 (diff) | |
download | cpython-27b7c7ebf1039e96cac41b6330cf16b5632d9e49.zip cpython-27b7c7ebf1039e96cac41b6330cf16b5632d9e49.tar.gz cpython-27b7c7ebf1039e96cac41b6330cf16b5632d9e49.tar.bz2 |
Initial checkin of asyncio package (== Tulip, == PEP 3156).
Diffstat (limited to 'Lib/test/test_asyncio')
22 files changed, 8864 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py new file mode 100644 index 0000000..ec483ea --- /dev/null +++ b/Lib/test/test_asyncio/__init__.py @@ -0,0 +1,26 @@ +import os +import sys +import unittest +from test.support import run_unittest + + +def suite(): + tests_file = os.path.join(os.path.dirname(__file__), 'tests.txt') + with open(tests_file) as fp: + test_names = fp.read().splitlines() + tests = unittest.TestSuite() + loader = unittest.TestLoader() + for test_name in test_names: + mod_name = 'test.' + test_name + try: + __import__(mod_name) + except unittest.SkipTest: + pass + else: + mod = sys.modules[mod_name] + tests.addTests(loader.loadTestsFromModule(mod)) + return tests + + +def test_main(): + run_unittest(suite()) diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py new file mode 100644 index 0000000..b549492 --- /dev/null +++ b/Lib/test/test_asyncio/__main__.py @@ -0,0 +1,5 @@ +from . import test_main + + +if __name__ == '__main__': + test_main() diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py new file mode 100644 index 0000000..f6ac0a3 --- /dev/null +++ b/Lib/test/test_asyncio/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py new file mode 100644 index 0000000..e83ca09 --- /dev/null +++ b/Lib/test/test_asyncio/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/Lib/test/test_asyncio/echo3.py b/Lib/test/test_asyncio/echo3.py new file mode 100644 index 0000000..f1f7ea7 --- /dev/null +++ b/Lib/test/test_asyncio/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/Lib/test/test_asyncio/sample.crt b/Lib/test/test_asyncio/sample.crt new file mode 100644 index 0000000..6a1e3f3 --- /dev/null +++ b/Lib/test/test_asyncio/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- diff --git a/Lib/test/test_asyncio/sample.key b/Lib/test/test_asyncio/sample.key new file mode 100644 index 0000000..edfea8d --- /dev/null +++ b/Lib/test/test_asyncio/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py new file mode 100644 index 0000000..d48d12c --- /dev/null +++ b/Lib/test/test_asyncio/test_base_events.py @@ -0,0 +1,590 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from asyncio import base_events +from asyncio import events +from asyncio import futures +from asyncio import protocols +from asyncio import tasks +from asyncio import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('asyncio.base_events.time') + @unittest.mock.patch('asyncio.base_events.asyncio_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py new file mode 100644 index 0000000..243f400 --- /dev/null +++ b/Lib/test/test_asyncio/test_events.py @@ -0,0 +1,1573 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from asyncio import futures +from asyncio import events +from asyncio import transports +from asyncio import protocols +from asyncio import selector_events +from asyncio import tasks +from asyncio import test_utils +from asyncio import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.Event(loop=loop), + 2: locks.Event(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, + ssl=test_utils.dummy_ssl_context()) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_create_server(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.create_server(factory, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.create_server( + factory, '127.0.0.1', 0, ssl=sslcontext) + + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, + ssl=test_utils.dummy_ssl_context()) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + def test_create_server_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + server.close() + + def test_create_server_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + f = self.loop.create_server(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + server.close() + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_create_server_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + server.close() + + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + server.close() + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('sockname') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + for _ in range(1000): + if server.nbytes: + break + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + for _ in range(1000): + if client.nbytes: + break + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('sockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + @tasks.coroutine + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop._stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from asyncio import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from asyncio import selectors + from asyncio import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())<cancelled>'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('asyncio.events.asyncio_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())<cancelled>'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.create_server, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('asyncio.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py new file mode 100644 index 0000000..9b5108c --- /dev/null +++ b/Lib/test/test_asyncio/test_futures.py @@ -0,0 +1,329 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future<PENDING>') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future<result=4>') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future<PENDING, [<function _fakefunc', + repr(f_few_callbacks)) + f_few_callbacks.cancel() + + f_many_callbacks = futures.Future(loop=self.loop) + for i in range(20): + f_many_callbacks.add_done_callback(_fakefunc) + r = repr(f_many_callbacks) + self.assertIn('Future<PENDING, [<function _fakefunc', r) + self.assertIn('<18 more>', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('asyncio.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=self.loop) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py new file mode 100644 index 0000000..31b4d64 --- /dev/null +++ b/Lib/test/test_asyncio/test_locks.py @@ -0,0 +1,765 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import locks +from asyncio import tasks +from asyncio import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.Event(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.Event(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.Event() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.Event(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.Event(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = locks.Event(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.Event(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py new file mode 100644 index 0000000..c52ade0 --- /dev/null +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -0,0 +1,480 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import asyncio +from asyncio.proactor_events import BaseProactorEventLoop +from asyncio.proactor_events import _ProactorSocketTransport +from asyncio.proactor_events import _ProactorWritePipeTransport +from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = asyncio.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + tr.close() + + def test_pause_resume(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = asyncio.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data1') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.pause() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.resume() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') + tr.close() + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_create_server(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_create_server_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = asyncio.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop._stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor._stop_serving.assert_called_with(sock) diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py new file mode 100644 index 0000000..8af4ee7 --- /dev/null +++ b/Lib/test/test_asyncio/test_queues.py @@ -0,0 +1,470 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import locks +from asyncio import queues +from asyncio import tasks +from asyncio import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith('<Queue'), fn(q)) + id_is_present = hex(id(q)) in fn(q) + self.assertEqual(expect_id, id_is_present) + + @tasks.coroutine + def add_getter(): + q = queues.Queue(loop=loop) + # Start a task that waits to get. + tasks.Task(q.get(), loop=loop) + # Let it start waiting. + yield from tasks.sleep(0.1, loop=loop) + self.assertTrue('_getters[1]' in fn(q)) + # resume q.get coroutine to finish generator + q.put_nowait(0) + + loop.run_until_complete(add_getter()) + + @tasks.coroutine + def add_putter(): + q = queues.Queue(maxsize=1, loop=loop) + q.put_nowait(1) + # Start a task that waits to put. + tasks.Task(q.put(2), loop=loop) + # Let it start waiting. + yield from tasks.sleep(0.1, loop=loop) + self.assertTrue('_putters[1]' in fn(q)) + # resume q.put coroutine to finish generator + q.get_nowait() + + loop.run_until_complete(add_putter()) + + q = queues.Queue(loop=loop) + q.put_nowait(1) + self.assertTrue('_queue=[1]' in fn(q)) + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + q = queues.Queue(loop=loop) + self.assertIs(q._loop, loop) + + q = queues.Queue(loop=self.loop) + self.assertIs(q._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + q = queues.Queue() + self.assertIs(q._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + self._test_repr_or_str(repr, True) + + def test_str(self): + self._test_repr_or_str(str, False) + + def test_empty(self): + q = queues.Queue(loop=self.loop) + self.assertTrue(q.empty()) + q.put_nowait(1) + self.assertFalse(q.empty()) + self.assertEqual(1, q.get_nowait()) + self.assertTrue(q.empty()) + + def test_full(self): + q = queues.Queue(loop=self.loop) + self.assertFalse(q.full()) + + q = queues.Queue(maxsize=1, loop=self.loop) + q.put_nowait(1) + self.assertTrue(q.full()) + + def test_order(self): + q = queues.Queue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 3, 2], items) + + def test_maxsize(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0.01 + self.assertAlmostEqual(0.02, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(maxsize=2, loop=loop) + self.assertEqual(2, q.maxsize) + have_been_put = [] + + @tasks.coroutine + def putter(): + for i in range(3): + yield from q.put(i) + have_been_put.append(i) + return True + + @tasks.coroutine + def test(): + t = tasks.Task(putter(), loop=loop) + yield from tasks.sleep(0.01, loop=loop) + + # The putter is blocked after putting two items. + self.assertEqual([0, 1], have_been_put) + self.assertEqual(0, q.get_nowait()) + + # Let the putter resume and put last item. + yield from tasks.sleep(0.01, loop=loop) + self.assertEqual([0, 1, 2], have_been_put) + self.assertEqual(1, q.get_nowait()) + self.assertEqual(2, q.get_nowait()) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + loop.run_until_complete(test()) + self.assertAlmostEqual(0.02, loop.time()) + + +class QueueGetTests(_QueueTestBase): + + def test_blocking_get(self): + q = queues.Queue(loop=self.loop) + q.put_nowait(1) + + @tasks.coroutine + def queue_get(): + return (yield from q.get()) + + res = self.loop.run_until_complete(queue_get()) + self.assertEqual(1, res) + + def test_get_with_putters(self): + q = queues.Queue(1, loop=self.loop) + q.put_nowait(1) + + waiter = futures.Future(loop=self.loop) + q._putters.append((2, waiter)) + + res = self.loop.run_until_complete(q.get()) + self.assertEqual(1, res) + self.assertTrue(waiter.done()) + self.assertIsNone(waiter.result()) + + def test_blocking_get_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + started = locks.Event(loop=loop) + finished = False + + @tasks.coroutine + def queue_get(): + nonlocal finished + started.set() + res = yield from q.get() + finished = True + return res + + @tasks.coroutine + def queue_put(): + loop.call_later(0.01, q.put_nowait, 1) + queue_get_task = tasks.Task(queue_get(), loop=loop) + yield from started.wait() + self.assertFalse(finished) + res = yield from queue_get_task + self.assertTrue(finished) + return res + + res = loop.run_until_complete(queue_put()) + self.assertEqual(1, res) + self.assertAlmostEqual(0.01, loop.time()) + + def test_nonblocking_get(self): + q = queues.Queue(loop=self.loop) + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_get_exception(self): + q = queues.Queue(loop=self.loop) + self.assertRaises(queues.Empty, q.get_nowait) + + def test_get_cancelled(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0.01 + self.assertAlmostEqual(0.061, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + + @tasks.coroutine + def queue_get(): + return (yield from tasks.wait_for(q.get(), 0.051, loop=loop)) + + @tasks.coroutine + def test(): + get_task = tasks.Task(queue_get(), loop=loop) + yield from tasks.sleep(0.01, loop=loop) # let the task start + q.put_nowait(1) + return (yield from get_task) + + self.assertEqual(1, loop.run_until_complete(test())) + self.assertAlmostEqual(0.06, loop.time()) + + def test_get_cancelled_race(self): + q = queues.Queue(loop=self.loop) + + t1 = tasks.Task(q.get(), loop=self.loop) + t2 = tasks.Task(q.get(), loop=self.loop) + + test_utils.run_briefly(self.loop) + t1.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t1.done()) + q.put_nowait('a') + test_utils.run_briefly(self.loop) + self.assertEqual(t2.result(), 'a') + + def test_get_with_waiting_putters(self): + q = queues.Queue(loop=self.loop, maxsize=1) + tasks.Task(q.put('a'), loop=self.loop) + tasks.Task(q.put('b'), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(self.loop.run_until_complete(q.get()), 'a') + self.assertEqual(self.loop.run_until_complete(q.get()), 'b') + + +class QueuePutTests(_QueueTestBase): + + def test_blocking_put(self): + q = queues.Queue(loop=self.loop) + + @tasks.coroutine + def queue_put(): + # No maxsize, won't block. + yield from q.put(1) + + self.loop.run_until_complete(queue_put()) + + def test_blocking_put_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(maxsize=1, loop=loop) + started = locks.Event(loop=loop) + finished = False + + @tasks.coroutine + def queue_put(): + nonlocal finished + started.set() + yield from q.put(1) + yield from q.put(2) + finished = True + + @tasks.coroutine + def queue_get(): + loop.call_later(0.01, q.get_nowait) + queue_put_task = tasks.Task(queue_put(), loop=loop) + yield from started.wait() + self.assertFalse(finished) + yield from queue_put_task + self.assertTrue(finished) + + loop.run_until_complete(queue_get()) + self.assertAlmostEqual(0.01, loop.time()) + + def test_nonblocking_put(self): + q = queues.Queue(loop=self.loop) + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_put_exception(self): + q = queues.Queue(maxsize=1, loop=self.loop) + q.put_nowait(1) + self.assertRaises(queues.Full, q.put_nowait, 2) + + def test_put_cancelled(self): + q = queues.Queue(loop=self.loop) + + @tasks.coroutine + def queue_put(): + yield from q.put(1) + return True + + @tasks.coroutine + def test(): + return (yield from q.get()) + + t = tasks.Task(queue_put(), loop=self.loop) + self.assertEqual(1, self.loop.run_until_complete(test())) + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_put_cancelled_race(self): + q = queues.Queue(loop=self.loop, maxsize=1) + + tasks.Task(q.put('a'), loop=self.loop) + tasks.Task(q.put('c'), loop=self.loop) + t = tasks.Task(q.put('b'), loop=self.loop) + + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t.done()) + self.assertEqual(q.get_nowait(), 'a') + self.assertEqual(q.get_nowait(), 'c') + + def test_put_with_waiting_getters(self): + q = queues.Queue(loop=self.loop) + t = tasks.Task(q.get(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.loop.run_until_complete(q.put('a')) + self.assertEqual(self.loop.run_until_complete(t), 'a') + + +class LifoQueueTests(_QueueTestBase): + + def test_order(self): + q = queues.LifoQueue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([2, 3, 1], items) + + +class PriorityQueueTests(_QueueTestBase): + + def test_order(self): + q = queues.PriorityQueue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 2, 3], items) + + +class JoinableQueueTests(_QueueTestBase): + + def test_task_done_underflow(self): + q = queues.JoinableQueue(loop=self.loop) + self.assertRaises(ValueError, q.task_done) + + def test_task_done(self): + q = queues.JoinableQueue(loop=self.loop) + for i in range(100): + q.put_nowait(i) + + accumulator = 0 + + # Two workers get items from the queue and call task_done after each. + # Join the queue and assert all items have been processed. + running = True + + @tasks.coroutine + def worker(): + nonlocal accumulator + + while running: + item = yield from q.get() + accumulator += item + q.task_done() + + @tasks.coroutine + def test(): + for _ in range(2): + tasks.Task(worker(), loop=self.loop) + + yield from q.join() + + self.loop.run_until_complete(test()) + self.assertEqual(sum(range(100)), accumulator) + + # close running generators + running = False + for i in range(2): + q.put_nowait(0) + + def test_join_empty_queue(self): + q = queues.JoinableQueue(loop=self.loop) + + # Test that a queue join()s successfully, and before anything else + # (done twice for insurance). + + @tasks.coroutine + def join(): + yield from q.join() + yield from q.join() + + self.loop.run_until_complete(join()) + + def test_format(self): + q = queues.JoinableQueue(loop=self.loop) + self.assertEqual(q._format(), 'maxsize=0') + + q._unfinished_tasks = 2 + self.assertEqual(q._format(), 'maxsize=0 tasks=2') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py new file mode 100644 index 0000000..0225e13 --- /dev/null +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -0,0 +1,1485 @@ +"""Tests for selector_events.py""" + +import collections +import errno +import gc +import pprint +import socket +import sys +import unittest +import unittest.mock +try: + import ssl +except ImportError: + ssl = None + +from asyncio import futures +from asyncio import selectors +from asyncio import test_utils +from asyncio.protocols import DatagramProtocol, Protocol +from asyncio.selector_events import BaseSelectorEventLoop +from asyncio.selector_events import _SelectorTransport +from asyncio.selector_events import _SelectorSslTransport +from asyncio.selector_events import _SelectorSocketTransport +from asyncio.selector_events import _SelectorDatagramTransport + + +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + self._ssock = unittest.mock.Mock() + self._csock = unittest.mock.Mock() + self._internal_fds += 1 + + +class BaseSelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) + + def test_make_socket_transport(self): + m = unittest.mock.Mock() + self.loop.add_reader = unittest.mock.Mock() + self.assertIsInstance( + self.loop._make_socket_transport(m, m), _SelectorSocketTransport) + + def test_make_ssl_transport(self): + m = unittest.mock.Mock() + self.loop.add_reader = unittest.mock.Mock() + self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_reader = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() + self.assertIsInstance( + self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) + + def test_close(self): + ssock = self.loop._ssock + ssock.fileno.return_value = 7 + csock = self.loop._csock + csock.fileno.return_value = 1 + remove_reader = self.loop.remove_reader = unittest.mock.Mock() + + self.loop._selector.close() + self.loop._selector = selector = unittest.mock.Mock() + self.loop.close() + self.assertIsNone(self.loop._selector) + self.assertIsNone(self.loop._csock) + self.assertIsNone(self.loop._ssock) + selector.close.assert_called_with() + ssock.close.assert_called_with() + csock.close.assert_called_with() + remove_reader.assert_called_with(7) + + self.loop.close() + self.loop.close() + + def test_close_no_selector(self): + ssock = self.loop._ssock + csock = self.loop._csock + remove_reader = self.loop.remove_reader = unittest.mock.Mock() + + self.loop._selector.close() + self.loop._selector = None + self.loop.close() + self.assertIsNone(self.loop._selector) + self.assertFalse(ssock.close.called) + self.assertFalse(csock.close.called) + self.assertFalse(remove_reader.called) + + def test_socketpair(self): + self.assertRaises(NotImplementedError, self.loop._socketpair) + + def test_read_from_self_tryagain(self): + self.loop._ssock.recv.side_effect = BlockingIOError + self.assertIsNone(self.loop._read_from_self()) + + def test_read_from_self_exception(self): + self.loop._ssock.recv.side_effect = OSError + self.assertRaises(OSError, self.loop._read_from_self) + + def test_write_to_self_tryagain(self): + self.loop._csock.send.side_effect = BlockingIOError + self.assertIsNone(self.loop._write_to_self()) + + def test_write_to_self_exception(self): + self.loop._csock.send.side_effect = OSError() + self.assertRaises(OSError, self.loop._write_to_self) + + def test_sock_recv(self): + sock = unittest.mock.Mock() + self.loop._sock_recv = unittest.mock.Mock() + + f = self.loop.sock_recv(sock, 1024) + self.assertIsInstance(f, futures.Future) + self.loop._sock_recv.assert_called_with(f, False, sock, 1024) + + def test__sock_recv_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_recv(f, False, sock, 1024) + self.assertFalse(sock.recv.called) + + def test__sock_recv_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_reader = unittest.mock.Mock() + self.loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) + + def test__sock_recv_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = BlockingIOError + + self.loop.add_reader = unittest.mock.Mock() + self.loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024), + self.loop.add_reader.call_args[0]) + + def test__sock_recv_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.recv.side_effect = OSError() + + self.loop._sock_recv(f, False, sock, 1024) + self.assertIs(err, f.exception()) + + def test_sock_sendall(self): + sock = unittest.mock.Mock() + self.loop._sock_sendall = unittest.mock.Mock() + + f = self.loop.sock_sendall(sock, b'data') + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, b'data'), + self.loop._sock_sendall.call_args[0]) + + def test_sock_sendall_nodata(self): + sock = unittest.mock.Mock() + self.loop._sock_sendall = unittest.mock.Mock() + + f = self.loop.sock_sendall(sock, b'') + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + self.assertFalse(self.loop._sock_sendall.called) + + def test__sock_sendall_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(sock.send.called) + + def test__sock_sendall_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) + + def test__sock_sendall_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = BlockingIOError + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_interrupted(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = InterruptedError + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.send.side_effect = OSError() + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertIs(f.exception(), err) + + def test__sock_sendall(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 4 + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test__sock_sendall_partial(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 2 + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'ta'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_none(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 0 + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test_sock_connect(self): + sock = unittest.mock.Mock() + self.loop._sock_connect = unittest.mock.Mock() + + f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, ('127.0.0.1', 8080)), + self.loop._sock_connect.call_args[0]) + + def test__sock_connect(self): + f = futures.Future(loop=self.loop) + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + self.assertTrue(sock.connect.called) + + def test__sock_connect_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.connect.called) + + def test__sock_connect_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) + + def test__sock_connect_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.EAGAIN + + self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() + + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual( + (10, self.loop._sock_connect, f, + True, sock, ('127.0.0.1', 8080)), + self.loop.add_writer.call_args[0]) + + def test__sock_connect_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.ENOTCONN + + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f.exception(), OSError) + + def test_sock_accept(self): + sock = unittest.mock.Mock() + self.loop._sock_accept = unittest.mock.Mock() + + f = self.loop.sock_accept(sock) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock), self.loop._sock_accept.call_args[0]) + + def test__sock_accept(self): + f = futures.Future(loop=self.loop) + + conn = unittest.mock.Mock() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.return_value = conn, ('127.0.0.1', 1000) + + self.loop._sock_accept(f, False, sock) + self.assertTrue(f.done()) + self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) + self.assertEqual((False,), conn.setblocking.call_args[0]) + + def test__sock_accept_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_accept(f, False, sock) + self.assertFalse(sock.accept.called) + + def test__sock_accept_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_reader = unittest.mock.Mock() + self.loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) + + def test__sock_accept_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = BlockingIOError + + self.loop.add_reader = unittest.mock.Mock() + self.loop._sock_accept(f, False, sock) + self.assertEqual( + (10, self.loop._sock_accept, f, True, sock), + self.loop.add_reader.call_args[0]) + + def test__sock_accept_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.accept.side_effect = OSError() + + self.loop._sock_accept(f, False, sock) + self.assertIs(err, f.exception()) + + def test_add_reader(self): + self.loop._selector.get_key.side_effect = KeyError + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertIsNone(w) + + def test_add_reader_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (reader, writer)) + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + def test_add_reader_existing_writer(self): + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, writer)) + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + def test_remove_reader(self): + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (None, None)) + self.assertFalse(self.loop.remove_reader(1)) + + self.assertTrue(self.loop._selector.unregister.called) + + def test_remove_reader_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) + self.assertTrue( + self.loop.remove_reader(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.loop._selector.modify.call_args[0]) + + def test_remove_reader_unknown(self): + self.loop._selector.get_key.side_effect = KeyError + self.assertFalse( + self.loop.remove_reader(1)) + + def test_add_writer(self): + self.loop._selector.get_key.side_effect = KeyError + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE, mask) + self.assertIsNone(r) + self.assertEqual(cb, w._callback) + + def test_add_writer_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, writer)) + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(reader, r) + self.assertEqual(cb, w._callback) + + def test_remove_writer(self): + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, None)) + self.assertFalse(self.loop.remove_writer(1)) + + self.assertTrue(self.loop._selector.unregister.called) + + def test_remove_writer_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) + self.assertTrue( + self.loop.remove_writer(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.loop._selector.modify.call_args[0]) + + def test_remove_writer_unknown(self): + self.loop._selector.get_key.side_effect = KeyError + self.assertFalse( + self.loop.remove_writer(1)) + + def test_process_events_read(self): + reader = unittest.mock.Mock() + reader._cancelled = False + + self.loop._add_callback = unittest.mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.assertTrue(self.loop._add_callback.called) + self.loop._add_callback.assert_called_with(reader) + + def test_process_events_read_cancelled(self): + reader = unittest.mock.Mock() + reader.cancelled = True + + self.loop.remove_reader = unittest.mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.loop.remove_reader.assert_called_with(1) + + def test_process_events_write(self): + writer = unittest.mock.Mock() + writer._cancelled = False + + self.loop._add_callback = unittest.mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop._add_callback.assert_called_with(writer) + + def test_process_events_write_cancelled(self): + writer = unittest.mock.Mock() + writer.cancelled = True + self.loop.remove_writer = unittest.mock.Mock() + + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop.remove_writer.assert_called_with(1) + + +class SelectorTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + + def test_ctor(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + self.assertIs(tr._loop, self.loop) + self.assertIs(tr._sock, self.sock) + self.assertIs(tr._sock_fd, 7) + + def test_abort(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = unittest.mock.Mock() + + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr.close() + + self.assertTrue(tr._closing) + self.assertEqual(1, self.loop.remove_reader_count[7]) + self.protocol.connection_lost(None) + self.assertEqual(tr._conn_lost, 1) + + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + def test_close_write_buffer(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._buffer.append(b'data') + tr.close() + + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_force_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._buffer.append(b'1') + self.loop.add_reader(7, unittest.mock.sentinel) + self.loop.add_writer(7, unittest.mock.sentinel) + tr._force_close(None) + + self.assertTrue(tr._closing) + self.assertEqual(tr._buffer, collections.deque()) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + + # second close should not remove reader + tr._force_close(None) + self.assertFalse(self.loop.readers) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + def test_fatal_error(self, m_exc): + exc = OSError() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(exc) + + m_exc.assert_called_with('Fatal error for %s', tr) + tr._force_close.assert_called_with(exc) + + def test_connection_lost(self): + exc = OSError() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._call_connection_lost(exc) + + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() + self.assertIsNone(tr._sock) + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class SelectorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) + self.sock = unittest.mock.Mock(socket.socket) + self.sock_fd = self.sock.fileno.return_value = 7 + + def test_ctor(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.loop.assert_reader(7, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + + _SelectorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + def test_pause_resume(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + tr.pause() + self.assertTrue(tr._paused) + self.assertFalse(7 in self.loop.readers) + tr.resume() + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + + def test_read_ready(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.protocol.data_received.assert_called_with(b'data') + + def test_read_ready_eof(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = unittest.mock.Mock() + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + transport.close.assert_called_with() + + def test_read_ready_eof_keep_open(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = unittest.mock.Mock() + + self.sock.recv.return_value = b'' + self.protocol.eof_received.return_value = True + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertFalse(transport.close.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_tryagain(self, m_exc): + self.sock.recv.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_tryagain_interrupted(self, m_exc): + self.sock.recv.side_effect = InterruptedError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_conn_reset(self, m_exc): + err = self.sock.recv.side_effect = ConnectionResetError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._force_close = unittest.mock.Mock() + transport._read_ready() + transport._force_close.assert_called_with(err) + + @unittest.mock.patch('logging.exception') + def test_read_ready_err(self, m_exc): + err = self.sock.recv.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with(err) + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + + def test_write_no_data(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_buffer(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual(collections.deque([b'data1', b'data2']), + transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'ta']), transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + self.sock.fileno.return_value = 7 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + data = b'data' + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_write_exception(self, m_log): + err = self.sock.send.side_effect = OSError() + + data = b'data' + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.write(data) + transport._fatal_error.assert_called_with(err) + transport._conn_lost = 1 + + self.sock.reset_mock() + transport.write(data) + self.assertFalse(self.sock.send.called) + self.assertEqual(transport._conn_lost, 2) + transport.write(data) + transport.write(data) + transport.write(data) + transport.write(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_write_str(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertFalse(self.loop.writers) + + def test_write_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.sock.send.assert_called_with(data) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test_write_ready_no_data(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport._write_ready) + + def test_write_ready_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'ta']), transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_ready_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer = collections.deque([b'data1', b'data2']) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data1data2']), transport._buffer) + + def test_write_ready_exception(self): + err = self.sock.send.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._write_ready() + transport._fatal_error.assert_called_with(err) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_write_ready_exception_and_close(self, m_log): + self.sock.send.side_effect = OSError() + remove_writer = self.loop.remove_writer = unittest.mock.Mock() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close() + transport._buffer.append(b'data') + transport._write_ready() + remove_writer.assert_called_with(self.sock_fd) + + def test_write_eof(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.sock.send.side_effect = BlockingIOError + tr.write(b'data') + tr.write_eof() + self.assertEqual(tr._buffer, collections.deque([b'data'])) + self.assertTrue(tr._eof) + self.assertFalse(self.sock.shutdown.called) + self.sock.send.side_effect = lambda _: 4 + tr._write_ready() + self.sock.send.assert_called_with(b'data') + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorSslTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.sslsock = unittest.mock.Mock() + self.sslsock.fileno.return_value = 1 + self.sslcontext = unittest.mock.Mock() + self.sslcontext.wrap_socket.return_value = self.sslsock + + def _make_one(self, create_waiter=None): + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + self.sock.reset_mock() + self.sslsock.reset_mock() + self.sslcontext.reset_mock() + self.loop.reset_counters() + return transport + + def test_on_handshake(self): + waiter = futures.Future(loop=self.loop) + tr = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, + waiter=waiter) + self.assertTrue(self.sslsock.do_handshake.called) + self.loop.assert_reader(1, tr._on_ready) + self.loop.assert_writer(1, tr._on_ready) + test_utils.run_briefly(self.loop) + self.assertIsNone(waiter.result()) + + def test_on_handshake_reader_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._on_handshake() + self.loop.assert_reader(1, transport._on_handshake) + + def test_on_handshake_writer_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._on_handshake() + self.loop.assert_writer(1, transport._on_handshake) + + def test_on_handshake_exc(self): + exc = ValueError() + self.sslsock.do_handshake.side_effect = exc + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._waiter = futures.Future(loop=self.loop) + transport._on_handshake() + self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) + + def test_on_handshake_base_exc(self): + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._waiter = futures.Future(loop=self.loop) + exc = BaseException() + self.sslsock.do_handshake.side_effect = exc + self.assertRaises(BaseException, transport._on_handshake) + self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) + + def test_pause_resume(self): + tr = self._make_one() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._on_ready) + tr.pause() + self.assertTrue(tr._paused) + self.assertFalse(1 in self.loop.readers) + tr.resume() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._on_ready) + + def test_write_no_data(self): + transport = self._make_one() + transport._buffer.append(b'data') + transport.write(b'') + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_str(self): + transport = self._make_one() + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = self._make_one() + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_write_exception(self, m_log): + transport = self._make_one() + transport._conn_lost = 1 + transport.write(b'data') + self.assertEqual(transport._buffer, collections.deque()) + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_on_ready_recv(self): + self.sslsock.recv.return_value = b'data' + transport = self._make_one() + transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) + + def test_on_ready_recv_eof(self): + self.sslsock.recv.return_value = b'' + transport = self._make_one() + transport.close = unittest.mock.Mock() + transport._on_ready() + transport.close.assert_called_with() + self.protocol.eof_received.assert_called_with() + + def test_on_ready_recv_conn_reset(self): + err = self.sslsock.recv.side_effect = ConnectionResetError() + transport = self._make_one() + transport._force_close = unittest.mock.Mock() + transport._on_ready() + transport._force_close.assert_called_with(err) + + def test_on_ready_recv_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + transport = self._make_one() + transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = BlockingIOError + transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = InterruptedError + transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_on_ready_recv_exc(self): + err = self.sslsock.recv.side_effect = OSError() + transport = self._make_one() + transport._fatal_error = unittest.mock.Mock() + transport._on_ready() + transport._fatal_error.assert_called_with(err) + + def test_on_ready_send(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + transport._on_ready() + self.assertEqual(collections.deque(), transport._buffer) + self.assertTrue(self.sslsock.send.called) + + def test_on_ready_send_none(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 0 + transport = self._make_one() + transport._buffer = collections.deque([b'data1', b'data2']) + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(collections.deque([b'data1data2']), transport._buffer) + + def test_on_ready_send_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + transport = self._make_one() + transport._buffer = collections.deque([b'data1', b'data2']) + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(collections.deque([b'ta1data2']), transport._buffer) + + def test_on_ready_send_closing_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + transport = self._make_one() + transport._buffer = collections.deque([b'data1', b'data2']) + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertFalse(self.sslsock.close.called) + + def test_on_ready_send_closing(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport.close() + transport._buffer = collections.deque([b'data']) + transport._on_ready() + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) + + def test_on_ready_send_closing_empty_buffer(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport.close() + transport._buffer = collections.deque() + transport._on_ready() + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) + + def test_on_ready_send_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + + self.sslsock.send.side_effect = ssl.SSLWantReadError + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + self.sslsock.send.side_effect = ssl.SSLWantWriteError + transport._on_ready() + self.assertEqual(collections.deque([b'data']), transport._buffer) + + self.sslsock.send.side_effect = BlockingIOError() + transport._on_ready() + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_on_ready_send_exc(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + err = self.sslsock.send.side_effect = OSError() + + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + transport._fatal_error = unittest.mock.Mock() + transport._on_ready() + transport._fatal_error.assert_called_with(err) + self.assertEqual(collections.deque(), transport._buffer) + + def test_write_eof(self): + tr = self._make_one() + self.assertFalse(tr.can_write_eof()) + self.assertRaises(NotImplementedError, tr.write_eof) + + def test_close(self): + tr = self._make_one() + tr.close() + + self.assertTrue(tr._closing) + self.assertEqual(1, self.loop.remove_reader_count[1]) + self.assertEqual(tr._conn_lost, 1) + + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertEqual(1, self.loop.remove_reader_count[1]) + + @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support') + def test_server_hostname(self): + _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, + server_hostname='localhost') + self.sslcontext.wrap_socket.assert_called_with( + self.sock, do_handshake_on_connect=False, server_side=False, + server_hostname='localhost') + + +class SelectorDatagramTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(DatagramProtocol) + self.sock = unittest.mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + + def test_read_ready(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.protocol.datagram_received.assert_called_with( + b'data', ('0.0.0.0', 1234)) + + def test_read_ready_tryagain(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + self.sock.recvfrom.side_effect = BlockingIOError + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + err = self.sock.recvfrom.side_effect = OSError() + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with(err) + + def test_sendto(self): + data = b'data' + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_no_data(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_tryagain(self): + data = b'data' + + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 12345)) + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_sendto_exception(self, m_log): + data = b'data' + err = self.sock.sendto.side_effect = OSError() + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + transport._fatal_error.assert_called_with(err) + transport._conn_lost = 1 + + transport._address = ('123',) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_sendto_connection_refused(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertEqual(transport._conn_lost, 0) + self.assertFalse(transport._fatal_error.called) + + def test_sendto_connection_refused_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data) + + self.assertTrue(transport._fatal_error.called) + + def test_sendto_str(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.assertRaises( + AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, address=(1,)) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.sendto(b'data', (1,)) + self.assertEqual(transport._conn_lost, 2) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((data, ('0.0.0.0', 12345))) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append((data, ())) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.sock.sendto.assert_called_with(data, ()) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test_sendto_ready_no_data(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_tryagain(self): + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + err = self.sock.sendto.side_effect = OSError() + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + transport._fatal_error.assert_called_with(err) + + def test_sendto_ready_connection_refused(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_connection_refused_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + def test_fatal_error_connected(self, m_exc): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.protocol.connection_refused.assert_called_with(err) + m_exc.assert_called_with('Fatal error for %s', transport) diff --git a/Lib/test/test_asyncio/test_selectors.py b/Lib/test/test_asyncio/test_selectors.py new file mode 100644 index 0000000..2f7dc69 --- /dev/null +++ b/Lib/test/test_asyncio/test_selectors.py @@ -0,0 +1,145 @@ +"""Tests for selectors.py.""" + +import unittest +import unittest.mock + +from asyncio import selectors + + +class FakeSelector(selectors.BaseSelector): + """Trivial non-abstract subclass of BaseSelector.""" + + def select(self, timeout=None): + raise NotImplementedError + + +class BaseSelectorTests(unittest.TestCase): + + def test_fileobj_to_fd(self): + self.assertEqual(10, selectors._fileobj_to_fd(10)) + + f = unittest.mock.Mock() + f.fileno.return_value = 10 + self.assertEqual(10, selectors._fileobj_to_fd(f)) + + f.fileno.side_effect = AttributeError + self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + + def test_selector_key_repr(self): + key = selectors.SelectorKey(10, 10, selectors.EVENT_READ, None) + self.assertEqual( + "SelectorKey(fileobj=10, fd=10, events=1, data=None)", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = FakeSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(KeyError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + + def test_unregister_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + self.assertRaises(KeyError, s.unregister, fobj) + + def test_modify_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + self.assertRaises(KeyError, s.modify, fobj, 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual( + selectors.SelectorKey(fobj, 10, selectors.EVENT_WRITE, None), + s.get_key(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual( + selectors.SelectorKey(fobj, 10, selectors.EVENT_READ, d2), + s.get_key(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = FakeSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = FakeSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + + def test_context_manager(self): + s = FakeSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + + def test_key_from_fd(self): + s = FakeSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) + + if hasattr(selectors.DefaultSelector, 'fileno'): + def test_fileno(self): + self.assertIsInstance(selectors.DefaultSelector().fileno(), int) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py new file mode 100644 index 0000000..011a09d --- /dev/null +++ b/Lib/test/test_asyncio/test_streams.py @@ -0,0 +1,361 @@ +"""Tests for streams.py.""" + +import gc +import ssl +import unittest +import unittest.mock + +from asyncio import events +from asyncio import streams +from asyncio import tasks +from asyncio import test_utils + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + @unittest.mock.patch('asyncio.streams.events') + def test_ctor_global_loop(self, m_events): + stream = streams.StreamReader() + self.assertIs(stream.loop, m_events.get_event_loop.return_value) + + def test_open_connection(self): + with test_utils.run_test_server() as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_open_connection_no_loop_ssl(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + try: + events.set_event_loop(self.loop) + f = streams.open_connection(*httpd.address, + ssl=test_utils.dummy_ssl_context()) + reader, writer = self.loop.run_until_complete(f) + finally: + events.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + def test_open_connection_error(self): + with test_utils.run_test_server() as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + + writer.close() + test_utils.run_briefly(self.loop) + + def test_feed_empty_data(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(30), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(1024), loop=self.loop) + + def cb(): + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(-1), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline(), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.loop.call_soon(cb) + + line = self.loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + stream = streams.StreamReader(7, loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + self.loop.run_until_complete(stream.readline()) + + data = self.loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader(loop=self.loop) + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader(loop=self.loop) + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def set_err(): + stream.set_exception(ValueError()) + + @tasks.coroutine + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + def test_exception_cancel(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def read_a_line(): + yield from stream.readline() + + t = tasks.Task(read_a_line(), loop=self.loop) + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + # The following line fails if set_exception() isn't careful. + stream.set_exception(RuntimeError('message')) + test_utils.run_briefly(self.loop) + self.assertIs(stream.waiter, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py new file mode 100644 index 0000000..57fb053 --- /dev/null +++ b/Lib/test/test_asyncio/test_tasks.py @@ -0,0 +1,1518 @@ +"""Tests for tasks.py.""" + +import gc +import unittest +import unittest.mock +from unittest.mock import Mock + +from asyncio import events +from asyncio import futures +from asyncio import tasks +from asyncio import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + gc.collect() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = events.new_event_loop() + t = tasks.Task(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_async_coroutine(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.async(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = events.new_event_loop() + t = tasks.async(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_async_future(self): + f_orig = futures.Future(loop=self.loop) + f_orig.set_result('ko') + + f = tasks.async(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + loop = events.new_event_loop() + + with self.assertRaises(ValueError): + f = tasks.async(f_orig, loop=loop) + + loop.close() + + f = tasks.async(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_async_task(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t_orig = tasks.Task(notmuch(), loop=self.loop) + t = tasks.async(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + loop = events.new_event_loop() + + with self.assertRaises(ValueError): + t = tasks.async(t_orig, loop=loop) + + loop.close() + + t = tasks.async(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_async_neither(self): + with self.assertRaises(TypeError): + tasks.async('ok') + + def test_task_repr(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'abc' + + t = tasks.Task(notmuch(), loop=self.loop) + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + gen = coro() + t = MyTask(gen, loop=self.loop) + self.assertEqual(repr(t), 'T[](<coro>)') + gen.close() + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(futures.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + return 42 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + res = yield from fut3 + return res + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = events.new_event_loop() + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from tasks.sleep(100, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + yield from fut + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + gen = notmuch() + task = tasks.Task(gen, loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + gen.close() + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + base_exc = BaseException() + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise base_exc + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + gen = coro() + try: + yield gen + finally: + gen.close() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except futures.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @tasks.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except futures.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + d, p = yield from tasks.wait([inner()], loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(tasks.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + yield from tasks.shield(inner(), loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + parent = tasks.gather(child1, child2, loop=self.loop) + outer = tasks.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + inner1 = tasks.shield(child1, loop=self.loop) + inner2 = tasks.shield(child2, loop=self.loop) + parent = tasks.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), futures.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + +class GatherTestsBase: + + def setUp(self): + self.one_loop = test_utils.TestLoop() + self.other_loop = test_utils.TestLoop() + + def tearDown(self): + self.one_loop.close() + self.other_loop.close() + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] + fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = Mock() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_return_exceptions(self): + a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + +class FutureGatherTests(GatherTestsBase, unittest.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + events.set_event_loop(self.one_loop) + self.addCleanup(events.set_event_loop, None) + fut = tasks.gather(*seq_or_iter) + self.assertIsInstance(fut, futures.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = tasks.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = futures.Future(loop=self.one_loop) + fut2 = futures.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + tasks.gather(fut1, fut2) + with self.assertRaises(ValueError): + tasks.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [futures.Future(loop=self.other_loop) for i in range(3)] + fut = tasks.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = tasks.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(a, b, c, d, e) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), futures.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [futures.Future(loop=self.one_loop) + for i in range(6)] + fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], futures.CancelledError) + self.assertIsInstance(res[4], futures.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): + + def setUp(self): + super().setUp() + events.set_event_loop(self.one_loop) + + def tearDown(self): + events.set_event_loop(None) + super().tearDown() + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + @tasks.coroutine + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @tasks.coroutine + def coro(): + return 'abc' + gen1 = coro() + gen2 = coro() + fut = tasks.gather(gen1, gen2) + self.assertIs(fut._loop, self.one_loop) + gen1.close() + gen2.close() + gen3 = coro() + gen4 = coro() + fut = tasks.gather(gen3, gen4, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + gen3.close() + gen4.close() + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = tasks.async(inner(), loop=self.one_loop) + child2 = tasks.async(inner(), loop=self.one_loop) + gatherer = None + + @tasks.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = tasks.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(futures.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @tasks.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = futures.Future(loop=self.one_loop) + b = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def outer(): + yield from tasks.gather(inner(a), inner(b), loop=self.one_loop) + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py new file mode 100644 index 0000000..fce2e6f --- /dev/null +++ b/Lib/test/test_asyncio/test_transports.py @@ -0,0 +1,55 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from asyncio import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py new file mode 100644 index 0000000..ea67862 --- /dev/null +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -0,0 +1,767 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import unittest +import unittest.mock + + +from asyncio import events +from asyncio import futures +from asyncio import protocols +from asyncio import test_utils +from asyncio import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + @unittest.mock.patch('os.waitpid') + def test__sig_chld_process_error(self, m_waitpid): + m_waitpid.side_effect = ChildProcessError + self.loop._sig_chld() + self.assertTrue(m_waitpid.called) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + # This is a bit overspecified. :-( + m_log.warning.assert_called_with( + 'pipe closed by peer or os.write(pipe, data) raised exception.') + + @unittest.mock.patch('os.write') + def test_write_close(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py new file mode 100644 index 0000000..4b04073 --- /dev/null +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -0,0 +1,95 @@ +import os +import sys +import unittest + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import asyncio + +from asyncio import windows_events +from asyncio import protocols +from asyncio import streams +from asyncio import transports +from asyncio import test_utils + + +class UpperProto(protocols.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = asyncio.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + server2 = windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = streams.StreamReader(loop=self.loop) + protocol = streams.StreamReaderProtocol(stream_reader) + trans, proto = yield from self.loop.create_pipe_connection( + lambda:protocol, ADDRESS) + self.assertIsInstance(trans, transports.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + return 'done' diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py new file mode 100644 index 0000000..4b96086 --- /dev/null +++ b/Lib/test/test_asyncio/test_windows_utils.py @@ -0,0 +1,136 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +from asyncio import windows_utils +from asyncio import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/Lib/test/test_asyncio/tests.txt b/Lib/test/test_asyncio/tests.txt new file mode 100644 index 0000000..e947721 --- /dev/null +++ b/Lib/test/test_asyncio/tests.txt @@ -0,0 +1,14 @@ +test_asyncio.test_base_events +test_asyncio.test_events +test_asyncio.test_futures +test_asyncio.test_locks +test_asyncio.test_proactor_events +test_asyncio.test_queues +test_asyncio.test_selector_events +test_asyncio.test_selectors +test_asyncio.test_streams +test_asyncio.test_tasks +test_asyncio.test_transports +test_asyncio.test_unix_events +test_asyncio.test_windows_events +test_asyncio.test_windows_utils |