summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio
diff options
context:
space:
mode:
authorGuido van Rossum <guido@dropbox.com>2013-10-17 20:40:50 (GMT)
committerGuido van Rossum <guido@dropbox.com>2013-10-17 20:40:50 (GMT)
commit27b7c7ebf1039e96cac41b6330cf16b5632d9e49 (patch)
tree814505b0f9d02a5cabdec733dcde70250b04ee28 /Lib/test/test_asyncio
parent5b37f97ea5ac9f6b33b0e0269c69539cbb478142 (diff)
downloadcpython-27b7c7ebf1039e96cac41b6330cf16b5632d9e49.zip
cpython-27b7c7ebf1039e96cac41b6330cf16b5632d9e49.tar.gz
cpython-27b7c7ebf1039e96cac41b6330cf16b5632d9e49.tar.bz2
Initial checkin of asyncio package (== Tulip, == PEP 3156).
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r--Lib/test/test_asyncio/__init__.py26
-rw-r--r--Lib/test/test_asyncio/__main__.py5
-rw-r--r--Lib/test/test_asyncio/echo.py6
-rw-r--r--Lib/test/test_asyncio/echo2.py6
-rw-r--r--Lib/test/test_asyncio/echo3.py9
-rw-r--r--Lib/test/test_asyncio/sample.crt14
-rw-r--r--Lib/test/test_asyncio/sample.key15
-rw-r--r--Lib/test/test_asyncio/test_base_events.py590
-rw-r--r--Lib/test/test_asyncio/test_events.py1573
-rw-r--r--Lib/test/test_asyncio/test_futures.py329
-rw-r--r--Lib/test/test_asyncio/test_locks.py765
-rw-r--r--Lib/test/test_asyncio/test_proactor_events.py480
-rw-r--r--Lib/test/test_asyncio/test_queues.py470
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py1485
-rw-r--r--Lib/test/test_asyncio/test_selectors.py145
-rw-r--r--Lib/test/test_asyncio/test_streams.py361
-rw-r--r--Lib/test/test_asyncio/test_tasks.py1518
-rw-r--r--Lib/test/test_asyncio/test_transports.py55
-rw-r--r--Lib/test/test_asyncio/test_unix_events.py767
-rw-r--r--Lib/test/test_asyncio/test_windows_events.py95
-rw-r--r--Lib/test/test_asyncio/test_windows_utils.py136
-rw-r--r--Lib/test/test_asyncio/tests.txt14
22 files changed, 8864 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py
new file mode 100644
index 0000000..ec483ea
--- /dev/null
+++ b/Lib/test/test_asyncio/__init__.py
@@ -0,0 +1,26 @@
+import os
+import sys
+import unittest
+from test.support import run_unittest
+
+
+def suite():
+ tests_file = os.path.join(os.path.dirname(__file__), 'tests.txt')
+ with open(tests_file) as fp:
+ test_names = fp.read().splitlines()
+ tests = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ for test_name in test_names:
+ mod_name = 'test.' + test_name
+ try:
+ __import__(mod_name)
+ except unittest.SkipTest:
+ pass
+ else:
+ mod = sys.modules[mod_name]
+ tests.addTests(loader.loadTestsFromModule(mod))
+ return tests
+
+
+def test_main():
+ run_unittest(suite())
diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py
new file mode 100644
index 0000000..b549492
--- /dev/null
+++ b/Lib/test/test_asyncio/__main__.py
@@ -0,0 +1,5 @@
+from . import test_main
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py
new file mode 100644
index 0000000..f6ac0a3
--- /dev/null
+++ b/Lib/test/test_asyncio/echo.py
@@ -0,0 +1,6 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ os.write(1, buf)
diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py
new file mode 100644
index 0000000..e83ca09
--- /dev/null
+++ b/Lib/test/test_asyncio/echo2.py
@@ -0,0 +1,6 @@
+import os
+
+if __name__ == '__main__':
+ buf = os.read(0, 1024)
+ os.write(1, b'OUT:'+buf)
+ os.write(2, b'ERR:'+buf)
diff --git a/Lib/test/test_asyncio/echo3.py b/Lib/test/test_asyncio/echo3.py
new file mode 100644
index 0000000..f1f7ea7
--- /dev/null
+++ b/Lib/test/test_asyncio/echo3.py
@@ -0,0 +1,9 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ try:
+ os.write(1, b'OUT:'+buf)
+ except OSError as ex:
+ os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii'))
diff --git a/Lib/test/test_asyncio/sample.crt b/Lib/test/test_asyncio/sample.crt
new file mode 100644
index 0000000..6a1e3f3
--- /dev/null
+++ b/Lib/test/test_asyncio/sample.crt
@@ -0,0 +1,14 @@
+-----BEGIN CERTIFICATE-----
+MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV
+UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi
+MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3
+MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp
+Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g
+U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn
+t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7
+gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg
+Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB
+BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB
+MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I
+I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB
+-----END CERTIFICATE-----
diff --git a/Lib/test/test_asyncio/sample.key b/Lib/test/test_asyncio/sample.key
new file mode 100644
index 0000000..edfea8d
--- /dev/null
+++ b/Lib/test/test_asyncio/sample.key
@@ -0,0 +1,15 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx
+/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi
+qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB
+AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y
+bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1
+iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb
+DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP
+lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL
+21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF
+ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0
+zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u
+GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq
+V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof
+-----END RSA PRIVATE KEY-----
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
new file mode 100644
index 0000000..d48d12c
--- /dev/null
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -0,0 +1,590 @@
+"""Tests for base_events.py"""
+
+import logging
+import socket
+import time
+import unittest
+import unittest.mock
+
+from asyncio import base_events
+from asyncio import events
+from asyncio import futures
+from asyncio import protocols
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class BaseEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = base_events.BaseEventLoop()
+ self.loop._selector = unittest.mock.Mock()
+ events.set_event_loop(None)
+
+ def test_not_implemented(self):
+ m = unittest.mock.Mock()
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_socket_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_ssl_transport, m, m, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_datagram_transport, m, m)
+ self.assertRaises(
+ NotImplementedError, self.loop._process_events, [])
+ self.assertRaises(
+ NotImplementedError, self.loop._write_to_self)
+ self.assertRaises(
+ NotImplementedError, self.loop._read_from_self)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_read_pipe_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_write_pipe_transport, m, m)
+ gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m)
+ self.assertRaises(NotImplementedError, next, iter(gen))
+
+ def test__add_callback_handle(self):
+ h = events.Handle(lambda: False, ())
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertIn(h, self.loop._ready)
+
+ def test__add_callback_timer(self):
+ h = events.TimerHandle(time.monotonic()+10, lambda: False, ())
+
+ self.loop._add_callback(h)
+ self.assertIn(h, self.loop._scheduled)
+
+ def test__add_callback_cancelled_handle(self):
+ h = events.Handle(lambda: False, ())
+ h.cancel()
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertFalse(self.loop._ready)
+
+ def test_set_default_executor(self):
+ executor = unittest.mock.Mock()
+ self.loop.set_default_executor(executor)
+ self.assertIs(executor, self.loop._default_executor)
+
+ def test_getnameinfo(self):
+ sockaddr = unittest.mock.Mock()
+ self.loop.run_in_executor = unittest.mock.Mock()
+ self.loop.getnameinfo(sockaddr)
+ self.assertEqual(
+ (None, socket.getnameinfo, sockaddr, 0),
+ self.loop.run_in_executor.call_args[0])
+
+ def test_call_soon(self):
+ def cb():
+ pass
+
+ h = self.loop.call_soon(cb)
+ self.assertEqual(h._callback, cb)
+ self.assertIsInstance(h, events.Handle)
+ self.assertIn(h, self.loop._ready)
+
+ def test_call_later(self):
+ def cb():
+ pass
+
+ h = self.loop.call_later(10.0, cb)
+ self.assertIsInstance(h, events.TimerHandle)
+ self.assertIn(h, self.loop._scheduled)
+ self.assertNotIn(h, self.loop._ready)
+
+ def test_call_later_negative_delays(self):
+ calls = []
+
+ def cb(arg):
+ calls.append(arg)
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop.call_later(-1, cb, 'a')
+ self.loop.call_later(-2, cb, 'b')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(calls, ['b', 'a'])
+
+ def test_time_and_call_at(self):
+ def cb():
+ self.loop.stop()
+
+ self.loop._process_events = unittest.mock.Mock()
+ when = self.loop.time() + 0.1
+ self.loop.call_at(when, cb)
+ t0 = self.loop.time()
+ self.loop.run_forever()
+ t1 = self.loop.time()
+ self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0)
+
+ def test_run_once_in_executor_handle(self):
+ def cb():
+ pass
+
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, events.Handle(cb, ()), ('',))
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, events.TimerHandle(10, cb, ()))
+
+ def test_run_once_in_executor_cancelled(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ h.cancel()
+
+ f = self.loop.run_in_executor(None, h)
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test_run_once_in_executor_plain(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ f = futures.Future(loop=self.loop)
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+
+ self.loop.set_default_executor(executor)
+
+ res = self.loop.run_in_executor(None, h)
+ self.assertIs(f, res)
+
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+ res = self.loop.run_in_executor(executor, h)
+ self.assertIs(f, res)
+ self.assertTrue(executor.submit.called)
+
+ f.cancel() # Don't complain about abandoned Future.
+
+ def test__run_once(self):
+ h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ())
+ h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ())
+
+ h1.cancel()
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._scheduled.append(h1)
+ self.loop._scheduled.append(h2)
+ self.loop._run_once()
+
+ t = self.loop._selector.select.call_args[0][0]
+ self.assertTrue(9.99 < t < 10.1, t)
+ self.assertEqual([h2], self.loop._scheduled)
+ self.assertTrue(self.loop._process_events.called)
+
+ @unittest.mock.patch('asyncio.base_events.time')
+ @unittest.mock.patch('asyncio.base_events.asyncio_log')
+ def test__run_once_logging(self, m_logging, m_time):
+ # Log to INFO level if timeout > 1.0 sec.
+ idx = -1
+ data = [10.0, 10.0, 12.0, 13.0]
+
+ def monotonic():
+ nonlocal data, idx
+ idx += 1
+ return data[idx]
+
+ m_time.monotonic = monotonic
+ m_logging.INFO = logging.INFO
+ m_logging.DEBUG = logging.DEBUG
+
+ self.loop._scheduled.append(
+ events.TimerHandle(11.0, lambda: True, ()))
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._run_once()
+ self.assertEqual(logging.INFO, m_logging.log.call_args[0][0])
+
+ idx = -1
+ data = [10.0, 10.0, 10.3, 13.0]
+ self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())]
+ self.loop._run_once()
+ self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0])
+
+ def test__run_once_schedule_handle(self):
+ handle = None
+ processed = False
+
+ def cb(loop):
+ nonlocal processed, handle
+ processed = True
+ handle = loop.call_soon(lambda: True)
+
+ h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,))
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._scheduled.append(h)
+ self.loop._run_once()
+
+ self.assertTrue(processed)
+ self.assertEqual([handle], list(self.loop._ready))
+
+ def test_run_until_complete_type_error(self):
+ self.assertRaises(
+ TypeError, self.loop.run_until_complete, 'blah')
+
+
+class MyProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, create_future=False):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = futures.Future()
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyDatagramProto(protocols.DatagramProtocol):
+ done = None
+
+ def __init__(self, create_future=False):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = futures.Future()
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def connection_refused(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class BaseEventLoopWithSelectorTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = events.new_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_connection_multiple_errors(self, m_socket):
+
+ class MyProto(protocols.Protocol):
+ pass
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ idx = -1
+ errors = ['err1', 'err2']
+
+ def _socket(*args, **kw):
+ nonlocal idx, errors
+ idx += 1
+ raise OSError(errors[idx])
+
+ m_socket.socket = _socket
+
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
+
+ def test_create_connection_host_port_sock(self):
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_host_port_sock(self):
+ coro = self.loop.create_connection(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_getaddrinfo(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_connect_err(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_multiple(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET)
+ with self.assertRaises(OSError):
+ self.loop.run_until_complete(coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_connection_multiple_errors_local_addr(self, m_socket):
+
+ def bind(addr):
+ if addr[0] == '0.0.0.1':
+ err = OSError('Err')
+ err.strerror = 'Err'
+ raise err
+
+ m_socket.socket.return_value.bind = bind
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError('Err2')
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
+ self.assertTrue(m_socket.socket.return_value.close.called)
+
+ def test_create_connection_no_local_addr(self):
+ @tasks.coroutine
+ def getaddrinfo(host, *args, **kw):
+ if host == 'example.com':
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+ else:
+ return []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_server_empty_host(self):
+ # if host is empty string use None instead
+ host = object()
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ nonlocal host
+ host = args[0]
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ fut = self.loop.create_server(MyProto, '', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertIsNone(host)
+
+ def test_create_server_host_port_sock(self):
+ fut = self.loop.create_server(
+ MyProto, '0.0.0.0', 0, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_host_port_sock(self):
+ fut = self.loop.create_server(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_getaddrinfo(self):
+ getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock()
+ getaddrinfo.return_value = []
+
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, f)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_server_cant_bind(self, m_socket):
+
+ class Err(OSError):
+ strerror = 'error'
+
+ m_socket.getaddrinfo.return_value = [
+ (2, 1, 6, '', ('127.0.0.1', 10100))]
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
+ m_socket.getaddrinfo.return_value = []
+
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_addr_error(self):
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr='localhost')
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 1, 2, 3))
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_connect_err(self):
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_socket_err(self, m_socket):
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_socket.socket.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, local_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_no_matching_family(self):
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol,
+ remote_addr=('127.0.0.1', 0), local_addr=('::1', 0))
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_setblk_err(self, m_socket):
+ m_socket.socket.return_value.setblocking.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+ self.assertTrue(
+ m_socket.socket.return_value.close.called)
+
+ def test_create_datagram_endpoint_noaddr_nofamily(self):
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_cant_bind(self, m_socket):
+ class Err(OSError):
+ pass
+
+ m_socket.AF_INET6 = socket.AF_INET6
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto,
+ local_addr=('127.0.0.1', 0), family=socket.AF_INET)
+ self.assertRaises(Err, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ def test_accept_connection_retry(self):
+ sock = unittest.mock.Mock()
+ sock.accept.side_effect = BlockingIOError()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertFalse(sock.close.called)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_accept_connection_exception(self, m_log):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = OSError()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertTrue(sock.close.called)
+ self.assertTrue(m_log.exception.called)
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
new file mode 100644
index 0000000..243f400
--- /dev/null
+++ b/Lib/test/test_asyncio/test_events.py
@@ -0,0 +1,1573 @@
+"""Tests for events.py."""
+
+import functools
+import gc
+import io
+import os
+import signal
+import socket
+try:
+ import ssl
+except ImportError:
+ ssl = None
+import subprocess
+import sys
+import threading
+import time
+import errno
+import unittest
+import unittest.mock
+from test.support import find_unused_port
+
+
+from asyncio import futures
+from asyncio import events
+from asyncio import transports
+from asyncio import protocols
+from asyncio import selector_events
+from asyncio import tasks
+from asyncio import test_utils
+from asyncio import locks
+
+
+class MyProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyDatagramProto(protocols.DatagramProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def connection_refused(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyReadPipeProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = ['INITIAL']
+ self.nbytes = 0
+ self.transport = None
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == ['INITIAL'], self.state
+ self.state.append('CONNECTED')
+
+ def data_received(self, data):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.state.append('EOF')
+
+ def connection_lost(self, exc):
+ assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state
+ self.state.append('CLOSED')
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyWritePipeProto(protocols.BaseProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.transport = None
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MySubprocessProtocol(protocols.SubprocessProtocol):
+
+ def __init__(self, loop):
+ self.state = 'INITIAL'
+ self.transport = None
+ self.connected = futures.Future(loop=loop)
+ self.completed = futures.Future(loop=loop)
+ self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)}
+ self.data = {1: b'', 2: b''}
+ self.returncode = None
+ self.got_data = {1: locks.Event(loop=loop),
+ 2: locks.Event(loop=loop)}
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ self.connected.set_result(None)
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ self.completed.set_result(None)
+
+ def pipe_data_received(self, fd, data):
+ assert self.state == 'CONNECTED', self.state
+ self.data[fd] += data
+ self.got_data[fd].set()
+
+ def pipe_connection_lost(self, fd, exc):
+ assert self.state == 'CONNECTED', self.state
+ if exc:
+ self.disconnects[fd].set_exception(exc)
+ else:
+ self.disconnects[fd].set_result(exc)
+
+ def process_exited(self):
+ assert self.state == 'CONNECTED', self.state
+ self.returncode = self.transport.get_returncode()
+
+
+class EventLoopTestsMixin:
+
+ def setUp(self):
+ super().setUp()
+ self.loop = self.create_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+ super().tearDown()
+
+ def test_run_until_complete_nesting(self):
+ @tasks.coroutine
+ def coro1():
+ yield
+
+ @tasks.coroutine
+ def coro2():
+ self.assertTrue(self.loop.is_running())
+ self.loop.run_until_complete(coro1())
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, coro2())
+
+ # Note: because of the default Windows timing granularity of
+ # 15.6 msec, we use fairly long sleep times here (~100 msec).
+
+ def test_run_until_complete(self):
+ t0 = self.loop.time()
+ self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop))
+ t1 = self.loop.time()
+ self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0)
+
+ def test_run_until_complete_stopped(self):
+ @tasks.coroutine
+ def cb():
+ self.loop.stop()
+ yield from tasks.sleep(0.1, loop=self.loop)
+ task = cb()
+ self.assertRaises(RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_call_later(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ self.loop.stop()
+
+ self.loop.call_later(0.1, callback, 'hello world')
+ t0 = time.monotonic()
+ self.loop.run_forever()
+ t1 = time.monotonic()
+ self.assertEqual(results, ['hello world'])
+ self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0)
+
+ def test_call_soon(self):
+ results = []
+
+ def callback(arg1, arg2):
+ results.append((arg1, arg2))
+ self.loop.stop()
+
+ self.loop.call_soon(callback, 'hello', 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, [('hello', 'world')])
+
+ def test_call_soon_threadsafe(self):
+ results = []
+ lock = threading.Lock()
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ def run_in_thread():
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ lock.release()
+
+ lock.acquire()
+ t = threading.Thread(target=run_in_thread)
+ t.start()
+
+ with lock:
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ t.join()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_call_soon_threadsafe_same_thread(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_run_in_executor(self):
+ def run(arg):
+ return (arg, threading.get_ident())
+ f2 = self.loop.run_in_executor(None, run, 'yo')
+ res, thread_id = self.loop.run_until_complete(f2)
+ self.assertEqual(res, 'yo')
+ self.assertNotEqual(thread_id, threading.get_ident())
+
+ def test_reader_callback(self):
+ r, w = test_utils.socketpair()
+ bytes_read = []
+
+ def reader():
+ try:
+ data = r.recv(1024)
+ except BlockingIOError:
+ # Spurious readiness notifications are possible
+ # at least on Linux -- see man select.
+ return
+ if data:
+ bytes_read.append(data)
+ else:
+ self.assertTrue(self.loop.remove_reader(r.fileno()))
+ r.close()
+
+ self.loop.add_reader(r.fileno(), reader)
+ self.loop.call_soon(w.send, b'abc')
+ test_utils.run_briefly(self.loop)
+ self.loop.call_soon(w.send, b'def')
+ test_utils.run_briefly(self.loop)
+ self.loop.call_soon(w.close)
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(b''.join(bytes_read), b'abcdef')
+
+ def test_writer_callback(self):
+ r, w = test_utils.socketpair()
+ w.setblocking(False)
+ self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024))
+ test_utils.run_briefly(self.loop)
+
+ def remove_writer():
+ self.assertTrue(self.loop.remove_writer(w.fileno()))
+
+ self.loop.call_soon(remove_writer)
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ w.close()
+ data = r.recv(256*1024)
+ r.close()
+ self.assertGreaterEqual(len(data), 200)
+
+ def test_sock_client_ops(self):
+ with test_utils.run_test_server() as httpd:
+ sock = socket.socket()
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, httpd.address))
+ self.loop.run_until_complete(
+ self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ data = self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ # consume data
+ self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ sock.close()
+
+ self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+
+ def test_sock_client_fail(self):
+ # Make sure that we will get an unused port
+ address = None
+ try:
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ address = s.getsockname()
+ finally:
+ s.close()
+
+ sock = socket.socket()
+ sock.setblocking(False)
+ with self.assertRaises(ConnectionRefusedError):
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ sock.close()
+
+ def test_sock_accept(self):
+ listener = socket.socket()
+ listener.setblocking(False)
+ listener.bind(('127.0.0.1', 0))
+ listener.listen(1)
+ client = socket.socket()
+ client.connect(listener.getsockname())
+
+ f = self.loop.sock_accept(listener)
+ conn, addr = self.loop.run_until_complete(f)
+ self.assertEqual(conn.gettimeout(), 0)
+ self.assertEqual(addr, client.getsockname())
+ self.assertEqual(client.getpeername(), listener.getsockname())
+ client.close()
+ conn.close()
+ listener.close()
+
+ @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL')
+ def test_add_signal_handler(self):
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+
+ # Check error behavior first.
+ self.assertRaises(
+ TypeError, self.loop.add_signal_handler, 'boom', my_handler)
+ self.assertRaises(
+ TypeError, self.loop.remove_signal_handler, 'boom')
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, signal.NSIG+1,
+ my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, signal.NSIG+1)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, 0, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, 0)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, -1, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, -1)
+ self.assertRaises(
+ RuntimeError, self.loop.add_signal_handler, signal.SIGKILL,
+ my_handler)
+ # Removing SIGKILL doesn't raise, since we don't call signal().
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL))
+ # Now set a handler and handle it.
+ self.loop.add_signal_handler(signal.SIGINT, my_handler)
+ test_utils.run_briefly(self.loop)
+ os.kill(os.getpid(), signal.SIGINT)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(caught, 1)
+ # Removing it should restore the default handler.
+ self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertEqual(signal.getsignal(signal.SIGINT),
+ signal.default_int_handler)
+ # Removing again returns False.
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT))
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_while_selecting(self):
+ # Test with a signal actually arriving during a select() call.
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+ self.loop.stop()
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once.
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_args(self):
+ some_args = (42,)
+ caught = 0
+
+ def my_handler(*args):
+ nonlocal caught
+ caught += 1
+ self.assertEqual(args, some_args)
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once.
+ self.loop.call_later(0.015, self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ def test_create_connection(self):
+ with test_utils.run_test_server() as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), *httpd.address)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def test_create_connection_sock(self):
+ with test_utils.run_test_server() as httpd:
+ sock = None
+ infos = self.loop.run_until_complete(
+ self.loop.getaddrinfo(
+ *httpd.address, type=socket.SOCK_STREAM))
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ except:
+ pass
+ else:
+ break
+ else:
+ assert False, 'Can not create socket.'
+
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), sock=sock)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_ssl_connection(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), *httpd.address,
+ ssl=test_utils.dummy_ssl_context())
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.assertTrue('ssl' in tr.__class__.__name__.lower())
+ self.assertIsNotNone(tr.get_extra_info('sockname'))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def test_create_connection_local_addr(self):
+ with test_utils.run_test_server() as httpd:
+ port = find_unused_port()
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=(httpd.address[0], port))
+ tr, pr = self.loop.run_until_complete(f)
+ expected = pr.transport.get_extra_info('sockname')[1]
+ self.assertEqual(port, expected)
+ tr.close()
+
+ def test_create_connection_local_addr_in_use(self):
+ with test_utils.run_test_server() as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=httpd.address)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+ self.assertIn(str(httpd.address), cm.exception.strerror)
+
+ def test_create_server(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyProto()
+ return proto
+
+ f = self.loop.create_server(factory, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(proto, MyProto)
+ self.assertEqual('INITIAL', proto.state)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_briefly(self.loop) # windows iocp
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
+ self.assertEqual('127.0.0.1',
+ proto.transport.get_extra_info('peername')[0])
+
+ # close connection
+ proto.transport.close()
+ test_utils.run_briefly(self.loop) # windows iocp
+
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # close server
+ server.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_server_ssl(self):
+ proto = None
+
+ class ClientMyProto(MyProto):
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def factory():
+ nonlocal proto
+ proto = MyProto(loop=self.loop)
+ return proto
+
+ here = os.path.dirname(__file__)
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext.load_cert_chain(
+ certfile=os.path.join(here, 'sample.crt'),
+ keyfile=os.path.join(here, 'sample.key'))
+
+ f = self.loop.create_server(
+ factory, '127.0.0.1', 0, ssl=sslcontext)
+
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '127.0.0.1')
+
+ f_c = self.loop.create_connection(ClientMyProto, host, port,
+ ssl=test_utils.dummy_ssl_context())
+ client, pr = self.loop.run_until_complete(f_c)
+
+ client.write(b'xxx')
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(proto, MyProto)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
+ self.assertEqual('127.0.0.1',
+ proto.transport.get_extra_info('peername')[0])
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # stop serving
+ server.close()
+
+ def test_create_server_sock(self):
+ proto = futures.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ proto.set_result(self)
+
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(TestMyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ self.assertIs(sock, sock_ob)
+
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+ server.close()
+
+ def test_create_server_addr_in_use(self):
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(MyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ f = self.loop.create_server(MyProto, host=host, port=port)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+
+ server.close()
+
+ @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported')
+ def test_create_server_dual_stack(self):
+ f_proto = futures.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ f_proto.set_result(self)
+
+ try_count = 0
+ while True:
+ try:
+ port = find_unused_port()
+ f = self.loop.create_server(TestMyProto, host=None, port=port)
+ server = self.loop.run_until_complete(f)
+ except OSError as ex:
+ if ex.errno == errno.EADDRINUSE:
+ try_count += 1
+ self.assertGreaterEqual(5, try_count)
+ continue
+ else:
+ raise
+ else:
+ break
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ f_proto = futures.Future(loop=self.loop)
+ client = socket.socket(socket.AF_INET6)
+ client.connect(('::1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ server.close()
+
+ def test_server_close(self):
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+
+ server.close()
+
+ client = socket.socket()
+ self.assertRaises(
+ ConnectionRefusedError, client.connect, ('127.0.0.1', port))
+ client.close()
+
+ def test_create_datagram_endpoint(self):
+ class TestMyDatagramProto(MyDatagramProto):
+ def __init__(inner_self):
+ super().__init__(loop=self.loop)
+
+ def datagram_received(self, data, addr):
+ super().datagram_received(data, addr)
+ self.transport.sendto(b'resp:'+data, addr)
+
+ coro = self.loop.create_datagram_endpoint(
+ TestMyDatagramProto, local_addr=('127.0.0.1', 0))
+ s_transport, server = self.loop.run_until_complete(coro)
+ host, port = s_transport.get_extra_info('sockname')
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ remote_addr=(host, port))
+ transport, client = self.loop.run_until_complete(coro)
+
+ self.assertEqual('INITIALIZED', client.state)
+ transport.sendto(b'xxx')
+ for _ in range(1000):
+ if server.nbytes:
+ break
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(3, server.nbytes)
+ for _ in range(1000):
+ if client.nbytes:
+ break
+ test_utils.run_briefly(self.loop)
+
+ # received
+ self.assertEqual(8, client.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(transport.get_extra_info('sockname'))
+
+ # close connection
+ transport.close()
+ self.loop.run_until_complete(client.done)
+ self.assertEqual('CLOSED', client.state)
+ server.transport.close()
+
+ def test_internal_fds(self):
+ loop = self.create_event_loop()
+ if not isinstance(loop, selector_events.BaseSelectorEventLoop):
+ return
+
+ self.assertEqual(1, loop._internal_fds)
+ loop.close()
+ self.assertEqual(0, loop._internal_fds)
+ self.assertIsNone(loop._csock)
+ self.assertIsNone(loop._ssock)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_read_pipe(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyReadPipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(rpipe, 'rb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(0, proto.nbytes)
+
+ self.loop.run_until_complete(connect())
+
+ os.write(wpipe, b'1')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(1, proto.nbytes)
+
+ os.write(wpipe, b'2345')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(5, proto.nbytes)
+
+ os.close(wpipe)
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual(
+ ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe(self):
+ proto = None
+ transport = None
+
+ def factory():
+ nonlocal proto
+ proto = MyWritePipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ nonlocal transport
+ t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+ transport = t
+
+ self.loop.run_until_complete(connect())
+
+ transport.write(b'1')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'1', data)
+
+ transport.write(b'2345')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'2345', data)
+ self.assertEqual('CONNECTED', proto.state)
+
+ os.close(rpipe)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe_disconnect_on_close(self):
+ proto = None
+ transport = None
+
+ def factory():
+ nonlocal proto
+ proto = MyWritePipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ nonlocal transport
+ t, p = yield from self.loop.connect_write_pipe(factory,
+ pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+ transport = t
+
+ self.loop.run_until_complete(connect())
+ self.assertEqual('CONNECTED', proto.state)
+
+ transport.write(b'1')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'1', data)
+
+ os.close(rpipe)
+
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ def test_prompt_cancellation(self):
+ r, w = test_utils.socketpair()
+ r.setblocking(False)
+ f = self.loop.sock_recv(r, 1)
+ ov = getattr(f, 'ov', None)
+ self.assertTrue(ov is None or ov.pending)
+
+ @tasks.coroutine
+ def main():
+ try:
+ self.loop.call_soon(f.cancel)
+ yield from f
+ except futures.CancelledError:
+ res = 'cancelled'
+ else:
+ res = None
+ finally:
+ self.loop.stop()
+ return res
+
+ start = time.monotonic()
+ t = tasks.Task(main(), loop=self.loop)
+ self.loop.run_forever()
+ elapsed = time.monotonic() - start
+
+ self.assertLess(elapsed, 0.1)
+ self.assertEqual(t.result(), 'cancelled')
+ self.assertRaises(futures.CancelledError, f.result)
+ self.assertTrue(ov is None or not ov.pending)
+ self.loop._stop_serving(r)
+
+ r.close()
+ w.close()
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_exec(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGTERM, proto.returncode)
+ self.assertEqual(b'Python The Winner', proto.data[1])
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_interactive(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ try:
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python ')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ proto.got_data[1].clear()
+ self.assertEqual(b'Python ', proto.data[1])
+
+ stdin.write(b'The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'Python The Winner', proto.data[1])
+ finally:
+ transp.close()
+
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGTERM, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_shell(self):
+ proto = None
+ transp = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'echo "Python"')
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.get_pipe_transport(0).close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(0, proto.returncode)
+ self.assertTrue(all(f.done() for f in proto.disconnects.values()))
+ self.assertEqual({1: b'Python\n', 2: b''}, proto.data)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_exitcode(self):
+ proto = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_close_after_finish(self):
+ proto = None
+ transp = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.assertIsNone(transp.get_pipe_transport(0))
+ self.assertIsNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+ self.assertIsNone(transp.close())
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_kill(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.kill()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGKILL, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_send_signal(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.send_signal(signal.SIGHUP)
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGHUP, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_stderr(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'test')
+
+ self.loop.run_until_complete(proto.completed)
+
+ transp.close()
+ self.assertEqual(b'OUT:test', proto.data[1])
+ self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2])
+ self.assertEqual(0, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_stderr_redirect_to_stdout(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog, stderr=subprocess.STDOUT)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ self.assertIsNotNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.completed)
+ self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'),
+ proto.data[1])
+ self.assertEqual(b'', proto.data[2])
+
+ transp.close()
+ self.assertEqual(0, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_close_client_stream(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo3.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdout = transp.get_pipe_transport(1)
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'OUT:test', proto.data[1])
+
+ stdout.close()
+ self.loop.run_until_complete(proto.disconnects[1])
+ stdin.write(b'xxx')
+ self.loop.run_until_complete(proto.got_data[2].wait())
+ self.assertEqual(b'ERR:BrokenPipeError', proto.data[2])
+
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGTERM, proto.returncode)
+
+
+if sys.platform == 'win32':
+ from asyncio import windows_events
+
+ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return windows_events.SelectorEventLoop()
+
+ class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return windows_events.ProactorEventLoop()
+
+ def test_create_ssl_connection(self):
+ raise unittest.SkipTest("IocpEventLoop imcompatible with SSL")
+
+ def test_create_server_ssl(self):
+ raise unittest.SkipTest("IocpEventLoop imcompatible with SSL")
+
+ def test_reader_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_reader_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_writer_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_writer_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_create_datagram_endpoint(self):
+ raise unittest.SkipTest(
+ "IocpEventLoop does not have create_datagram_endpoint()")
+else:
+ from asyncio import selectors
+ from asyncio import unix_events
+
+ if hasattr(selectors, 'KqueueSelector'):
+ class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(
+ selectors.KqueueSelector())
+
+ if hasattr(selectors, 'EpollSelector'):
+ class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.EpollSelector())
+
+ if hasattr(selectors, 'PollSelector'):
+ class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.PollSelector())
+
+ # Should always exist.
+ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.SelectSelector())
+
+
+class HandleTests(unittest.TestCase):
+
+ def test_handle(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ h = events.Handle(callback, args)
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())<cancelled>'), r)
+
+ def test_make_handle(self):
+ def callback(*args):
+ return args
+ h1 = events.Handle(callback, ())
+ self.assertRaises(
+ AssertionError, events.make_handle, h1, ())
+
+ @unittest.mock.patch('asyncio.events.asyncio_log')
+ def test_callback_with_exception(self, log):
+ def callback():
+ raise ValueError()
+
+ h = events.Handle(callback, ())
+ h._run()
+ self.assertTrue(log.exception.called)
+
+
+class TimerTests(unittest.TestCase):
+
+ def test_hash(self):
+ when = time.monotonic()
+ h = events.TimerHandle(when, lambda: False, ())
+ self.assertEqual(hash(h), hash(when))
+
+ def test_timer(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ when = time.monotonic()
+ h = events.TimerHandle(when, callback, args)
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())<cancelled>'), r)
+
+ self.assertRaises(AssertionError,
+ events.TimerHandle, None, callback, args)
+
+ def test_timer_comparison(self):
+ def callback(*args):
+ return args
+
+ when = time.monotonic()
+
+ h1 = events.TimerHandle(when, callback, ())
+ h2 = events.TimerHandle(when, callback, ())
+ # TODO: Use assertLess etc.
+ self.assertFalse(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertTrue(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertFalse(h2 > h1)
+ self.assertTrue(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertTrue(h1 == h2)
+ self.assertFalse(h1 != h2)
+
+ h2.cancel()
+ self.assertFalse(h1 == h2)
+
+ h1 = events.TimerHandle(when, callback, ())
+ h2 = events.TimerHandle(when + 10.0, callback, ())
+ self.assertTrue(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertFalse(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertTrue(h2 > h1)
+ self.assertFalse(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h3 = events.Handle(callback, ())
+ self.assertIs(NotImplemented, h1.__eq__(h3))
+ self.assertIs(NotImplemented, h1.__ne__(h3))
+
+
+class AbstractEventLoopTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = unittest.mock.Mock()
+ loop = events.AbstractEventLoop()
+ self.assertRaises(
+ NotImplementedError, loop.run_forever)
+ self.assertRaises(
+ NotImplementedError, loop.run_until_complete, None)
+ self.assertRaises(
+ NotImplementedError, loop.stop)
+ self.assertRaises(
+ NotImplementedError, loop.is_running)
+ self.assertRaises(
+ NotImplementedError, loop.call_later, None, None)
+ self.assertRaises(
+ NotImplementedError, loop.call_at, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon, None)
+ self.assertRaises(
+ NotImplementedError, loop.time)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon_threadsafe, None)
+ self.assertRaises(
+ NotImplementedError, loop.run_in_executor, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.set_default_executor, f)
+ self.assertRaises(
+ NotImplementedError, loop.getaddrinfo, 'localhost', 8080)
+ self.assertRaises(
+ NotImplementedError, loop.getnameinfo, ('localhost', 8080))
+ self.assertRaises(
+ NotImplementedError, loop.create_connection, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_server, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_datagram_endpoint, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_reader, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_reader, 1)
+ self.assertRaises(
+ NotImplementedError, loop.add_writer, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_writer, 1)
+ self.assertRaises(
+ NotImplementedError, loop.sock_recv, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_sendall, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_connect, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.sock_accept, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_signal_handler, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.connect_read_pipe, f,
+ unittest.mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.connect_write_pipe, f,
+ unittest.mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_shell, f,
+ unittest.mock.sentinel)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_exec, f)
+
+
+class ProtocolsAbsTests(unittest.TestCase):
+
+ def test_empty(self):
+ f = unittest.mock.Mock()
+ p = protocols.Protocol()
+ self.assertIsNone(p.connection_made(f))
+ self.assertIsNone(p.connection_lost(f))
+ self.assertIsNone(p.data_received(f))
+ self.assertIsNone(p.eof_received())
+
+ dp = protocols.DatagramProtocol()
+ self.assertIsNone(dp.connection_made(f))
+ self.assertIsNone(dp.connection_lost(f))
+ self.assertIsNone(dp.connection_refused(f))
+ self.assertIsNone(dp.datagram_received(f, f))
+
+ sp = protocols.SubprocessProtocol()
+ self.assertIsNone(sp.connection_made(f))
+ self.assertIsNone(sp.connection_lost(f))
+ self.assertIsNone(sp.pipe_data_received(1, f))
+ self.assertIsNone(sp.pipe_connection_lost(1, f))
+ self.assertIsNone(sp.process_exited())
+
+
+class PolicyTests(unittest.TestCase):
+
+ def test_event_loop_policy(self):
+ policy = events.AbstractEventLoopPolicy()
+ self.assertRaises(NotImplementedError, policy.get_event_loop)
+ self.assertRaises(NotImplementedError, policy.set_event_loop, object())
+ self.assertRaises(NotImplementedError, policy.new_event_loop)
+
+ def test_get_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+ self.assertIsNone(policy._loop)
+
+ loop = policy.get_event_loop()
+ self.assertIsInstance(loop, events.AbstractEventLoop)
+
+ self.assertIs(policy._loop, loop)
+ self.assertIs(loop, policy.get_event_loop())
+ loop.close()
+
+ def test_get_event_loop_after_set_none(self):
+ policy = events.DefaultEventLoopPolicy()
+ policy.set_event_loop(None)
+ self.assertRaises(AssertionError, policy.get_event_loop)
+
+ @unittest.mock.patch('asyncio.events.threading.current_thread')
+ def test_get_event_loop_thread(self, m_current_thread):
+
+ def f():
+ policy = events.DefaultEventLoopPolicy()
+ self.assertRaises(AssertionError, policy.get_event_loop)
+
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+
+ def test_new_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+
+ loop = policy.new_event_loop()
+ self.assertIsInstance(loop, events.AbstractEventLoop)
+ loop.close()
+
+ def test_set_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+ old_loop = policy.get_event_loop()
+
+ self.assertRaises(AssertionError, policy.set_event_loop, object())
+
+ loop = policy.new_event_loop()
+ policy.set_event_loop(loop)
+ self.assertIs(loop, policy.get_event_loop())
+ self.assertIsNot(old_loop, policy.get_event_loop())
+ loop.close()
+ old_loop.close()
+
+ def test_get_event_loop_policy(self):
+ policy = events.get_event_loop_policy()
+ self.assertIsInstance(policy, events.AbstractEventLoopPolicy)
+ self.assertIs(policy, events.get_event_loop_policy())
+
+ def test_set_event_loop_policy(self):
+ self.assertRaises(
+ AssertionError, events.set_event_loop_policy, object())
+
+ old_policy = events.get_event_loop_policy()
+
+ policy = events.DefaultEventLoopPolicy()
+ events.set_event_loop_policy(policy)
+ self.assertIs(policy, events.get_event_loop_policy())
+ self.assertIsNot(policy, old_policy)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py
new file mode 100644
index 0000000..9b5108c
--- /dev/null
+++ b/Lib/test/test_asyncio/test_futures.py
@@ -0,0 +1,329 @@
+"""Tests for futures.py."""
+
+import concurrent.futures
+import threading
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import test_utils
+
+
+def _fakefunc(f):
+ return f
+
+
+class FutureTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_initial_state(self):
+ f = futures.Future(loop=self.loop)
+ self.assertFalse(f.cancelled())
+ self.assertFalse(f.done())
+ f.cancel()
+ self.assertTrue(f.cancelled())
+
+ def test_init_constructor_default_loop(self):
+ try:
+ events.set_event_loop(self.loop)
+ f = futures.Future()
+ self.assertIs(f._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_constructor_positional(self):
+ # Make sure Future does't accept a positional argument
+ self.assertRaises(TypeError, futures.Future, 42)
+
+ def test_cancel(self):
+ f = futures.Future(loop=self.loop)
+ self.assertTrue(f.cancel())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(futures.CancelledError, f.result)
+ self.assertRaises(futures.CancelledError, f.exception)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_result(self):
+ f = futures.Future(loop=self.loop)
+ self.assertRaises(futures.InvalidStateError, f.result)
+
+ f.set_result(42)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 42)
+ self.assertEqual(f.exception(), None)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_exception(self):
+ exc = RuntimeError()
+ f = futures.Future(loop=self.loop)
+ self.assertRaises(futures.InvalidStateError, f.exception)
+
+ f.set_exception(exc)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(RuntimeError, f.result)
+ self.assertEqual(f.exception(), exc)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_yield_from_twice(self):
+ f = futures.Future(loop=self.loop)
+
+ def fixture():
+ yield 'A'
+ x = yield from f
+ yield 'B', x
+ y = yield from f
+ yield 'C', y
+
+ g = fixture()
+ self.assertEqual(next(g), 'A') # yield 'A'.
+ self.assertEqual(next(g), f) # First yield from f.
+ f.set_result(42)
+ self.assertEqual(next(g), ('B', 42)) # yield 'B', x.
+ # The second "yield from f" does not yield f.
+ self.assertEqual(next(g), ('C', 42)) # yield 'C', y.
+
+ def test_repr(self):
+ f_pending = futures.Future(loop=self.loop)
+ self.assertEqual(repr(f_pending), 'Future<PENDING>')
+ f_pending.cancel()
+
+ f_cancelled = futures.Future(loop=self.loop)
+ f_cancelled.cancel()
+ self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>')
+
+ f_result = futures.Future(loop=self.loop)
+ f_result.set_result(4)
+ self.assertEqual(repr(f_result), 'Future<result=4>')
+ self.assertEqual(f_result.result(), 4)
+
+ exc = RuntimeError()
+ f_exception = futures.Future(loop=self.loop)
+ f_exception.set_exception(exc)
+ self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>')
+ self.assertIs(f_exception.exception(), exc)
+
+ f_few_callbacks = futures.Future(loop=self.loop)
+ f_few_callbacks.add_done_callback(_fakefunc)
+ self.assertIn('Future<PENDING, [<function _fakefunc',
+ repr(f_few_callbacks))
+ f_few_callbacks.cancel()
+
+ f_many_callbacks = futures.Future(loop=self.loop)
+ for i in range(20):
+ f_many_callbacks.add_done_callback(_fakefunc)
+ r = repr(f_many_callbacks)
+ self.assertIn('Future<PENDING, [<function _fakefunc', r)
+ self.assertIn('<18 more>', r)
+ f_many_callbacks.cancel()
+
+ def test_copy_state(self):
+ # Test the internal _copy_state method since it's being directly
+ # invoked in other modules.
+ f = futures.Future(loop=self.loop)
+ f.set_result(10)
+
+ newf = futures.Future(loop=self.loop)
+ newf._copy_state(f)
+ self.assertTrue(newf.done())
+ self.assertEqual(newf.result(), 10)
+
+ f_exception = futures.Future(loop=self.loop)
+ f_exception.set_exception(RuntimeError())
+
+ newf_exception = futures.Future(loop=self.loop)
+ newf_exception._copy_state(f_exception)
+ self.assertTrue(newf_exception.done())
+ self.assertRaises(RuntimeError, newf_exception.result)
+
+ f_cancelled = futures.Future(loop=self.loop)
+ f_cancelled.cancel()
+
+ newf_cancelled = futures.Future(loop=self.loop)
+ newf_cancelled._copy_state(f_cancelled)
+ self.assertTrue(newf_cancelled.cancelled())
+
+ def test_iter(self):
+ fut = futures.Future(loop=self.loop)
+
+ def coro():
+ yield from fut
+
+ def test():
+ arg1, arg2 = coro()
+
+ self.assertRaises(AssertionError, test)
+ fut.cancel()
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_abandoned(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_result_unretrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_result_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ fut.result()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_exception_unretrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ del fut
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_exception_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ fut.exception()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_exception_result_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ self.assertRaises(RuntimeError, fut.result)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ def test_wrap_future(self):
+
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = futures.wrap_future(f1, loop=self.loop)
+ res, ident = self.loop.run_until_complete(f2)
+ self.assertIsInstance(f2, futures.Future)
+ self.assertEqual(res, 'oi')
+ self.assertNotEqual(ident, threading.get_ident())
+
+ def test_wrap_future_future(self):
+ f1 = futures.Future(loop=self.loop)
+ f2 = futures.wrap_future(f1)
+ self.assertIs(f1, f2)
+
+ @unittest.mock.patch('asyncio.futures.events')
+ def test_wrap_future_use_global_loop(self, m_events):
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = futures.wrap_future(f1)
+ self.assertIs(m_events.get_event_loop.return_value, f2._loop)
+
+
+class FutureDoneCallbackTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def run_briefly(self):
+ test_utils.run_briefly(self.loop)
+
+ def _make_callback(self, bag, thing):
+ # Create a callback function that appends thing to bag.
+ def bag_appender(future):
+ bag.append(thing)
+ return bag_appender
+
+ def _new_future(self):
+ return futures.Future(loop=self.loop)
+
+ def test_callbacks_invoked_on_set_result(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 42))
+ f.add_done_callback(self._make_callback(bag, 17))
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [42, 17])
+ self.assertEqual(f.result(), 'foo')
+
+ def test_callbacks_invoked_on_set_exception(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 100))
+
+ self.assertEqual(bag, [])
+ exc = RuntimeError()
+ f.set_exception(exc)
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [100])
+ self.assertEqual(f.exception(), exc)
+
+ def test_remove_done_callback(self):
+ bag = []
+ f = self._new_future()
+ cb1 = self._make_callback(bag, 1)
+ cb2 = self._make_callback(bag, 2)
+ cb3 = self._make_callback(bag, 3)
+
+ # Add one cb1 and one cb2.
+ f.add_done_callback(cb1)
+ f.add_done_callback(cb2)
+
+ # One instance of cb2 removed. Now there's only one cb1.
+ self.assertEqual(f.remove_done_callback(cb2), 1)
+
+ # Never had any cb3 in there.
+ self.assertEqual(f.remove_done_callback(cb3), 0)
+
+ # After this there will be 6 instances of cb1 and one of cb2.
+ f.add_done_callback(cb2)
+ for i in range(5):
+ f.add_done_callback(cb1)
+
+ # Remove all instances of cb1. One cb2 remains.
+ self.assertEqual(f.remove_done_callback(cb1), 6)
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [2])
+ self.assertEqual(f.result(), 'foo')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
new file mode 100644
index 0000000..31b4d64
--- /dev/null
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -0,0 +1,765 @@
+"""Tests for lock.py"""
+
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import locks
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class LockTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ lock = locks.Lock(loop=loop)
+ self.assertIs(lock._loop, loop)
+
+ lock = locks.Lock(loop=self.loop)
+ self.assertIs(lock._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ lock = locks.Lock()
+ self.assertIs(lock._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ lock = locks.Lock(loop=self.loop)
+ self.assertTrue(repr(lock).endswith('[unlocked]>'))
+
+ @tasks.coroutine
+ def acquire_lock():
+ yield from lock
+
+ self.loop.run_until_complete(acquire_lock())
+ self.assertTrue(repr(lock).endswith('[locked]>'))
+
+ def test_lock(self):
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_acquire(self):
+ lock = locks.Lock(loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from lock.acquire()):
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from lock.acquire()):
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from lock.acquire()):
+ result.append(3)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_acquire_cancel(self):
+ lock = locks.Lock(loop=self.loop)
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ task = tasks.Task(lock.acquire(), loop=self.loop)
+ self.loop.call_soon(task.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, task)
+ self.assertFalse(lock._waiters)
+
+ def test_cancel_race(self):
+ # Several tasks:
+ # - A acquires the lock
+ # - B is blocked in aqcuire()
+ # - C is blocked in aqcuire()
+ #
+ # Now, concurrently:
+ # - B is cancelled
+ # - A releases the lock
+ #
+ # If B's waiter is marked cancelled but not yet removed from
+ # _waiters, A's release() call will crash when trying to set
+ # B's waiter; instead, it should move on to C's waiter.
+
+ # Setup: A has the lock, b and c are waiting.
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def lockit(name, blocker):
+ yield from lock.acquire()
+ try:
+ if blocker is not None:
+ yield from blocker
+ finally:
+ lock.release()
+
+ fa = futures.Future(loop=self.loop)
+ ta = tasks.Task(lockit('A', fa), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(lock.locked())
+ tb = tasks.Task(lockit('B', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 1)
+ tc = tasks.Task(lockit('C', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 2)
+
+ # Create the race and check.
+ # Without the fix this failed at the last assert.
+ fa.set_result(None)
+ tb.cancel()
+ self.assertTrue(lock._waiters[0].cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(lock.locked())
+ self.assertTrue(ta.done())
+ self.assertTrue(tb.cancelled())
+ self.assertTrue(tc.done())
+
+ def test_release_not_acquired(self):
+ lock = locks.Lock(loop=self.loop)
+
+ self.assertRaises(RuntimeError, lock.release)
+
+ def test_release_no_waiters(self):
+ lock = locks.Lock(loop=self.loop)
+ self.loop.run_until_complete(lock.acquire())
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_context_manager(self):
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(lock.locked())
+
+ self.assertFalse(lock.locked())
+
+ def test_context_manager_no_yield(self):
+ lock = locks.Lock(loop=self.loop)
+
+ try:
+ with lock:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+
+class EventTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ ev = locks.Event(loop=loop)
+ self.assertIs(ev._loop, loop)
+
+ ev = locks.Event(loop=self.loop)
+ self.assertIs(ev._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ ev = locks.Event()
+ self.assertIs(ev._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertTrue(repr(ev).endswith('[unset]>'))
+
+ ev.set()
+ self.assertTrue(repr(ev).endswith('[set]>'))
+
+ def test_wait(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from ev.wait()):
+ result.append(2)
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from ev.wait()):
+ result.append(3)
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ ev.set()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([3, 1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertIsNone(t1.result())
+ self.assertTrue(t2.done())
+ self.assertIsNone(t2.result())
+ self.assertTrue(t3.done())
+ self.assertIsNone(t3.result())
+
+ def test_wait_on_set(self):
+ ev = locks.Event(loop=self.loop)
+ ev.set()
+
+ res = self.loop.run_until_complete(ev.wait())
+ self.assertTrue(res)
+
+ def test_wait_cancel(self):
+ ev = locks.Event(loop=self.loop)
+
+ wait = tasks.Task(ev.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(ev._waiters)
+
+ def test_clear(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ self.assertTrue(ev.is_set())
+
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ def test_clear_with_waiters(self):
+ ev = locks.Event(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+ return True
+
+ t = tasks.Task(c1(result), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ ev.set()
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ ev.set()
+ self.assertEqual(1, len(ev._waiters))
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertEqual(0, len(ev._waiters))
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+
+class ConditionTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ cond = locks.Condition(loop=loop)
+ self.assertIs(cond._loop, loop)
+
+ cond = locks.Condition(loop=self.loop)
+ self.assertIs(cond._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ cond = locks.Condition()
+ self.assertIs(cond._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_wait(self):
+ cond = locks.Condition(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertFalse(cond.locked())
+
+ self.assertTrue(self.loop.run_until_complete(cond.acquire()))
+ cond.notify()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.notify(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(cond.locked())
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_wait_cancel(self):
+ cond = locks.Condition(loop=self.loop)
+ self.loop.run_until_complete(cond.acquire())
+
+ wait = tasks.Task(cond.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(cond._condition_waiters)
+ self.assertTrue(cond.locked())
+
+ def test_wait_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, cond.wait())
+
+ def test_wait_for(self):
+ cond = locks.Condition(loop=self.loop)
+ presult = False
+
+ def predicate():
+ return presult
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait_for(predicate)):
+ result.append(1)
+ cond.release()
+ return True
+
+ t = tasks.Task(c1(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ presult = True
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_wait_for_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+
+ # predicate can return true immediately
+ res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3]))
+ self.assertEqual([1, 2, 3], res)
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete,
+ cond.wait_for(lambda: False))
+
+ def test_notify(self):
+ cond = locks.Condition(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ cond.release()
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.notify(2048)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_notify_all(self):
+ cond = locks.Condition(loop=self.loop)
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify_all()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+
+ def test_notify_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify)
+
+ def test_notify_all_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify_all)
+
+
+class SemaphoreTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ sem = locks.Semaphore(loop=loop)
+ self.assertIs(sem._loop, loop)
+
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertIs(sem._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ sem = locks.Semaphore()
+ self.assertIs(sem._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertTrue(repr(sem).endswith('[unlocked,value:1]>'))
+
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(repr(sem).endswith('[locked]>'))
+
+ def test_semaphore(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertEqual(1, sem._value)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(sem.locked())
+ self.assertEqual(0, sem._value)
+
+ sem.release()
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ def test_semaphore_value(self):
+ self.assertRaises(ValueError, locks.Semaphore, -1)
+
+ def test_acquire(self):
+ sem = locks.Semaphore(3, loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertFalse(sem.locked())
+
+ @tasks.coroutine
+ def c1(result):
+ yield from sem.acquire()
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from sem.acquire()
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from sem.acquire()
+ result.append(3)
+ return True
+
+ @tasks.coroutine
+ def c4(result):
+ yield from sem.acquire()
+ result.append(4)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(2, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ t4 = tasks.Task(c4(result), loop=self.loop)
+
+ sem.release()
+ sem.release()
+ self.assertEqual(2, sem._value)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(0, sem._value)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(1, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+ self.assertFalse(t4.done())
+
+ # cleanup locked semaphore
+ sem.release()
+
+ def test_acquire_cancel(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+
+ acquire = tasks.Task(sem.acquire(), loop=self.loop)
+ self.loop.call_soon(acquire.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, acquire)
+ self.assertFalse(sem._waiters)
+
+ def test_release_not_acquired(self):
+ sem = locks.Semaphore(bound=True, loop=self.loop)
+
+ self.assertRaises(ValueError, sem.release)
+
+ def test_release_no_waiters(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(sem.locked())
+
+ sem.release()
+ self.assertFalse(sem.locked())
+
+ def test_context_manager(self):
+ sem = locks.Semaphore(2, loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(sem.locked())
+
+ self.assertEqual(2, sem._value)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
new file mode 100644
index 0000000..c52ade0
--- /dev/null
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -0,0 +1,480 @@
+"""Tests for proactor_events.py"""
+
+import socket
+import unittest
+import unittest.mock
+
+import asyncio
+from asyncio.proactor_events import BaseProactorEventLoop
+from asyncio.proactor_events import _ProactorSocketTransport
+from asyncio.proactor_events import _ProactorWritePipeTransport
+from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from asyncio import test_utils
+
+
+class ProactorSocketTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.proactor = unittest.mock.Mock()
+ self.loop._proactor = self.proactor
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+
+ def test_ctor(self):
+ fut = asyncio.Future(loop=self.loop)
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+ self.protocol.connection_made(tr)
+ self.proactor.recv.assert_called_with(self.sock, 4096)
+
+ def test_loop_reading(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._loop_reading()
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertFalse(self.protocol.eof_received.called)
+
+ def test_loop_reading_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'data')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_loop_reading_no_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+
+ self.assertRaises(AssertionError, tr._loop_reading, res)
+
+ tr.close = unittest.mock.Mock()
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.assertFalse(self.loop._proactor.recv.called)
+ self.assertTrue(self.protocol.eof_received.called)
+ self.assertTrue(tr.close.called)
+
+ def test_loop_reading_aborted(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(err)
+
+ def test_loop_reading_aborted_closing(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = True
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+
+ def test_loop_reading_aborted_is_fatal(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = False
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertTrue(tr._fatal_error.called)
+
+ def test_loop_reading_conn_reset_lost(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionResetError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = False
+ tr._fatal_error = unittest.mock.Mock()
+ tr._force_close = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+ tr._force_close.assert_called_with(err)
+
+ def test_loop_reading_exception(self):
+ err = self.loop._proactor.recv.side_effect = (OSError())
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(err)
+
+ def test_write(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._loop_writing = unittest.mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [b'data'])
+ self.assertTrue(tr._loop_writing.called)
+
+ def test_write_no_data(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr.write(b'')
+ self.assertFalse(tr._buffer)
+
+ def test_write_more(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = unittest.mock.Mock()
+ tr._loop_writing = unittest.mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [b'data'])
+ self.assertFalse(tr._loop_writing.called)
+
+ def test_loop_writing(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'da', b'ta']
+ tr._loop_writing()
+ self.loop._proactor.send.assert_called_with(self.sock, b'data')
+ self.loop._proactor.send.return_value.add_done_callback.\
+ assert_called_with(tr._loop_writing)
+
+ @unittest.mock.patch('asyncio.proactor_events.asyncio_log')
+ def test_loop_writing_err(self, m_log):
+ err = self.loop._proactor.send.side_effect = OSError()
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._buffer = [b'da', b'ta']
+ tr._loop_writing()
+ tr._fatal_error.assert_called_with(err)
+ tr._conn_lost = 1
+
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [])
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_loop_writing_stop(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(b'data')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = fut
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+
+ def test_loop_writing_closing(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(1)
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = fut
+ tr.close()
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_abort(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._force_close = unittest.mock.Mock()
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertTrue(tr._closing)
+ self.assertEqual(tr._conn_lost, 1)
+
+ self.protocol.connection_lost.reset_mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_write_fut(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = unittest.mock.Mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_buffer(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ @unittest.mock.patch('asyncio.proactor_events.asyncio_log')
+ def test_fatal_error(self, m_logging):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._force_close = unittest.mock.Mock()
+ tr._fatal_error(None)
+ self.assertTrue(tr._force_close.called)
+ self.assertTrue(m_logging.exception.called)
+
+ def test_force_close(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ read_fut = tr._read_fut = unittest.mock.Mock()
+ write_fut = tr._write_fut = unittest.mock.Mock()
+ tr._force_close(None)
+
+ read_fut.cancel.assert_called_with()
+ write_fut.cancel.assert_called_with()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual([], tr._buffer)
+ self.assertEqual(tr._conn_lost, 1)
+
+ def test_force_close_idempotent(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = True
+ tr._force_close(None)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_fatal_error_2(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ tr._force_close(None)
+
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual([], tr._buffer)
+
+ def test_call_connection_lost(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._call_connection_lost(None)
+ self.assertTrue(self.protocol.connection_lost.called)
+ self.assertTrue(self.sock.close.called)
+
+ def test_write_eof(self):
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr._eof_written)
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+ def test_write_eof_write_pipe(self):
+ tr = _ProactorWritePipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_buffer_write_pipe(self):
+ tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_duplex_pipe(self):
+ tr = _ProactorDuplexPipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertFalse(tr.can_write_eof())
+ with self.assertRaises(NotImplementedError):
+ tr.write_eof()
+ tr.close()
+
+ def test_pause_resume(self):
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ futures = []
+ for msg in [b'data1', b'data2', b'data3', b'data4', b'']:
+ f = asyncio.Future(loop=self.loop)
+ f.set_result(msg)
+ futures.append(f)
+ self.loop._proactor.recv.side_effect = futures
+ self.loop._run_once()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data1')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.pause()
+ self.assertTrue(tr._paused)
+ for i in range(10):
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.resume()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data3')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data4')
+ tr.close()
+
+
+class BaseProactorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.proactor = unittest.mock.Mock()
+
+ self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock()
+
+ class EventLoop(BaseProactorEventLoop):
+ def _socketpair(s):
+ return (self.ssock, self.csock)
+
+ self.loop = EventLoop(self.proactor)
+
+ @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon')
+ @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair')
+ def test_ctor(self, socketpair, call_soon):
+ ssock, csock = socketpair.return_value = (
+ unittest.mock.Mock(), unittest.mock.Mock())
+ loop = BaseProactorEventLoop(self.proactor)
+ self.assertIs(loop._ssock, ssock)
+ self.assertIs(loop._csock, csock)
+ self.assertEqual(loop._internal_fds, 1)
+ call_soon.assert_called_with(loop._loop_self_reading)
+
+ def test_close_self_pipe(self):
+ self.loop._close_self_pipe()
+ self.assertEqual(self.loop._internal_fds, 0)
+ self.assertTrue(self.ssock.close.called)
+ self.assertTrue(self.csock.close.called)
+ self.assertIsNone(self.loop._ssock)
+ self.assertIsNone(self.loop._csock)
+
+ def test_close(self):
+ self.loop._close_self_pipe = unittest.mock.Mock()
+ self.loop.close()
+ self.assertTrue(self.loop._close_self_pipe.called)
+ self.assertTrue(self.proactor.close.called)
+ self.assertIsNone(self.loop._proactor)
+
+ self.loop._close_self_pipe.reset_mock()
+ self.loop.close()
+ self.assertFalse(self.loop._close_self_pipe.called)
+
+ def test_sock_recv(self):
+ self.loop.sock_recv(self.sock, 1024)
+ self.proactor.recv.assert_called_with(self.sock, 1024)
+
+ def test_sock_sendall(self):
+ self.loop.sock_sendall(self.sock, b'data')
+ self.proactor.send.assert_called_with(self.sock, b'data')
+
+ def test_sock_connect(self):
+ self.loop.sock_connect(self.sock, 123)
+ self.proactor.connect.assert_called_with(self.sock, 123)
+
+ def test_sock_accept(self):
+ self.loop.sock_accept(self.sock)
+ self.proactor.accept.assert_called_with(self.sock)
+
+ def test_socketpair(self):
+ self.assertRaises(
+ NotImplementedError, BaseProactorEventLoop, self.proactor)
+
+ def test_make_socket_transport(self):
+ tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
+ self.assertIsInstance(tr, _ProactorSocketTransport)
+
+ def test_loop_self_reading(self):
+ self.loop._loop_self_reading()
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_fut(self):
+ fut = unittest.mock.Mock()
+ self.loop._loop_self_reading(fut)
+ self.assertTrue(fut.result.called)
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_exception(self):
+ self.loop.close = unittest.mock.Mock()
+ self.proactor.recv.side_effect = OSError()
+ self.assertRaises(OSError, self.loop._loop_self_reading)
+ self.assertTrue(self.loop.close.called)
+
+ def test_write_to_self(self):
+ self.loop._write_to_self()
+ self.csock.send.assert_called_with(b'x')
+
+ def test_process_events(self):
+ self.loop._process_events([])
+
+ @unittest.mock.patch('asyncio.proactor_events.asyncio_log')
+ def test_create_server(self, m_log):
+ pf = unittest.mock.Mock()
+ call_soon = self.loop.call_soon = unittest.mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ self.assertTrue(call_soon.called)
+
+ # callback
+ loop = call_soon.call_args[0][0]
+ loop()
+ self.proactor.accept.assert_called_with(self.sock)
+
+ # conn
+ fut = unittest.mock.Mock()
+ fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock())
+
+ make_tr = self.loop._make_socket_transport = unittest.mock.Mock()
+ loop(fut)
+ self.assertTrue(fut.result.called)
+ self.assertTrue(make_tr.called)
+
+ # exception
+ fut.result.side_effect = OSError()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+ self.assertTrue(m_log.exception.called)
+
+ def test_create_server_cancel(self):
+ pf = unittest.mock.Mock()
+ call_soon = self.loop.call_soon = unittest.mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ loop = call_soon.call_args[0][0]
+
+ # cancelled
+ fut = asyncio.Future(loop=self.loop)
+ fut.cancel()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+
+ def test_stop_serving(self):
+ sock = unittest.mock.Mock()
+ self.loop._stop_serving(sock)
+ self.assertTrue(sock.close.called)
+ self.proactor._stop_serving.assert_called_with(sock)
diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py
new file mode 100644
index 0000000..8af4ee7
--- /dev/null
+++ b/Lib/test/test_asyncio/test_queues.py
@@ -0,0 +1,470 @@
+"""Tests for queues.py"""
+
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import locks
+from asyncio import queues
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class _QueueTestBase(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+
+class QueueBasicTests(_QueueTestBase):
+
+ def _test_repr_or_str(self, fn, expect_id):
+ """Test Queue's repr or str.
+
+ fn is repr or str. expect_id is True if we expect the Queue's id to
+ appear in fn(Queue()).
+ """
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+ self.assertTrue(fn(q).startswith('<Queue'), fn(q))
+ id_is_present = hex(id(q)) in fn(q)
+ self.assertEqual(expect_id, id_is_present)
+
+ @tasks.coroutine
+ def add_getter():
+ q = queues.Queue(loop=loop)
+ # Start a task that waits to get.
+ tasks.Task(q.get(), loop=loop)
+ # Let it start waiting.
+ yield from tasks.sleep(0.1, loop=loop)
+ self.assertTrue('_getters[1]' in fn(q))
+ # resume q.get coroutine to finish generator
+ q.put_nowait(0)
+
+ loop.run_until_complete(add_getter())
+
+ @tasks.coroutine
+ def add_putter():
+ q = queues.Queue(maxsize=1, loop=loop)
+ q.put_nowait(1)
+ # Start a task that waits to put.
+ tasks.Task(q.put(2), loop=loop)
+ # Let it start waiting.
+ yield from tasks.sleep(0.1, loop=loop)
+ self.assertTrue('_putters[1]' in fn(q))
+ # resume q.put coroutine to finish generator
+ q.get_nowait()
+
+ loop.run_until_complete(add_putter())
+
+ q = queues.Queue(loop=loop)
+ q.put_nowait(1)
+ self.assertTrue('_queue=[1]' in fn(q))
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ q = queues.Queue(loop=loop)
+ self.assertIs(q._loop, loop)
+
+ q = queues.Queue(loop=self.loop)
+ self.assertIs(q._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ q = queues.Queue()
+ self.assertIs(q._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ self._test_repr_or_str(repr, True)
+
+ def test_str(self):
+ self._test_repr_or_str(str, False)
+
+ def test_empty(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertTrue(q.empty())
+ q.put_nowait(1)
+ self.assertFalse(q.empty())
+ self.assertEqual(1, q.get_nowait())
+ self.assertTrue(q.empty())
+
+ def test_full(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertFalse(q.full())
+
+ q = queues.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertTrue(q.full())
+
+ def test_order(self):
+ q = queues.Queue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 3, 2], items)
+
+ def test_maxsize(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.02, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(maxsize=2, loop=loop)
+ self.assertEqual(2, q.maxsize)
+ have_been_put = []
+
+ @tasks.coroutine
+ def putter():
+ for i in range(3):
+ yield from q.put(i)
+ have_been_put.append(i)
+ return True
+
+ @tasks.coroutine
+ def test():
+ t = tasks.Task(putter(), loop=loop)
+ yield from tasks.sleep(0.01, loop=loop)
+
+ # The putter is blocked after putting two items.
+ self.assertEqual([0, 1], have_been_put)
+ self.assertEqual(0, q.get_nowait())
+
+ # Let the putter resume and put last item.
+ yield from tasks.sleep(0.01, loop=loop)
+ self.assertEqual([0, 1, 2], have_been_put)
+ self.assertEqual(1, q.get_nowait())
+ self.assertEqual(2, q.get_nowait())
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ loop.run_until_complete(test())
+ self.assertAlmostEqual(0.02, loop.time())
+
+
+class QueueGetTests(_QueueTestBase):
+
+ def test_blocking_get(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from q.get())
+
+ res = self.loop.run_until_complete(queue_get())
+ self.assertEqual(1, res)
+
+ def test_get_with_putters(self):
+ q = queues.Queue(1, loop=self.loop)
+ q.put_nowait(1)
+
+ waiter = futures.Future(loop=self.loop)
+ q._putters.append((2, waiter))
+
+ res = self.loop.run_until_complete(q.get())
+ self.assertEqual(1, res)
+ self.assertTrue(waiter.done())
+ self.assertIsNone(waiter.result())
+
+ def test_blocking_get_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+ started = locks.Event(loop=loop)
+ finished = False
+
+ @tasks.coroutine
+ def queue_get():
+ nonlocal finished
+ started.set()
+ res = yield from q.get()
+ finished = True
+ return res
+
+ @tasks.coroutine
+ def queue_put():
+ loop.call_later(0.01, q.put_nowait, 1)
+ queue_get_task = tasks.Task(queue_get(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ res = yield from queue_get_task
+ self.assertTrue(finished)
+ return res
+
+ res = loop.run_until_complete(queue_put())
+ self.assertEqual(1, res)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_get(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_get_exception(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertRaises(queues.Empty, q.get_nowait)
+
+ def test_get_cancelled(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.061, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from tasks.wait_for(q.get(), 0.051, loop=loop))
+
+ @tasks.coroutine
+ def test():
+ get_task = tasks.Task(queue_get(), loop=loop)
+ yield from tasks.sleep(0.01, loop=loop) # let the task start
+ q.put_nowait(1)
+ return (yield from get_task)
+
+ self.assertEqual(1, loop.run_until_complete(test()))
+ self.assertAlmostEqual(0.06, loop.time())
+
+ def test_get_cancelled_race(self):
+ q = queues.Queue(loop=self.loop)
+
+ t1 = tasks.Task(q.get(), loop=self.loop)
+ t2 = tasks.Task(q.get(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ t1.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(t1.done())
+ q.put_nowait('a')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(t2.result(), 'a')
+
+ def test_get_with_waiting_putters(self):
+ q = queues.Queue(loop=self.loop, maxsize=1)
+ tasks.Task(q.put('a'), loop=self.loop)
+ tasks.Task(q.put('b'), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'a')
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'b')
+
+
+class QueuePutTests(_QueueTestBase):
+
+ def test_blocking_put(self):
+ q = queues.Queue(loop=self.loop)
+
+ @tasks.coroutine
+ def queue_put():
+ # No maxsize, won't block.
+ yield from q.put(1)
+
+ self.loop.run_until_complete(queue_put())
+
+ def test_blocking_put_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(maxsize=1, loop=loop)
+ started = locks.Event(loop=loop)
+ finished = False
+
+ @tasks.coroutine
+ def queue_put():
+ nonlocal finished
+ started.set()
+ yield from q.put(1)
+ yield from q.put(2)
+ finished = True
+
+ @tasks.coroutine
+ def queue_get():
+ loop.call_later(0.01, q.get_nowait)
+ queue_put_task = tasks.Task(queue_put(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ yield from queue_put_task
+ self.assertTrue(finished)
+
+ loop.run_until_complete(queue_get())
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_put(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_put_exception(self):
+ q = queues.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertRaises(queues.Full, q.put_nowait, 2)
+
+ def test_put_cancelled(self):
+ q = queues.Queue(loop=self.loop)
+
+ @tasks.coroutine
+ def queue_put():
+ yield from q.put(1)
+ return True
+
+ @tasks.coroutine
+ def test():
+ return (yield from q.get())
+
+ t = tasks.Task(queue_put(), loop=self.loop)
+ self.assertEqual(1, self.loop.run_until_complete(test()))
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_put_cancelled_race(self):
+ q = queues.Queue(loop=self.loop, maxsize=1)
+
+ tasks.Task(q.put('a'), loop=self.loop)
+ tasks.Task(q.put('c'), loop=self.loop)
+ t = tasks.Task(q.put('b'), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ t.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(t.done())
+ self.assertEqual(q.get_nowait(), 'a')
+ self.assertEqual(q.get_nowait(), 'c')
+
+ def test_put_with_waiting_getters(self):
+ q = queues.Queue(loop=self.loop)
+ t = tasks.Task(q.get(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.loop.run_until_complete(q.put('a'))
+ self.assertEqual(self.loop.run_until_complete(t), 'a')
+
+
+class LifoQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.LifoQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([2, 3, 1], items)
+
+
+class PriorityQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.PriorityQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 2, 3], items)
+
+
+class JoinableQueueTests(_QueueTestBase):
+
+ def test_task_done_underflow(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ self.assertRaises(ValueError, q.task_done)
+
+ def test_task_done(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ for i in range(100):
+ q.put_nowait(i)
+
+ accumulator = 0
+
+ # Two workers get items from the queue and call task_done after each.
+ # Join the queue and assert all items have been processed.
+ running = True
+
+ @tasks.coroutine
+ def worker():
+ nonlocal accumulator
+
+ while running:
+ item = yield from q.get()
+ accumulator += item
+ q.task_done()
+
+ @tasks.coroutine
+ def test():
+ for _ in range(2):
+ tasks.Task(worker(), loop=self.loop)
+
+ yield from q.join()
+
+ self.loop.run_until_complete(test())
+ self.assertEqual(sum(range(100)), accumulator)
+
+ # close running generators
+ running = False
+ for i in range(2):
+ q.put_nowait(0)
+
+ def test_join_empty_queue(self):
+ q = queues.JoinableQueue(loop=self.loop)
+
+ # Test that a queue join()s successfully, and before anything else
+ # (done twice for insurance).
+
+ @tasks.coroutine
+ def join():
+ yield from q.join()
+ yield from q.join()
+
+ self.loop.run_until_complete(join())
+
+ def test_format(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ self.assertEqual(q._format(), 'maxsize=0')
+
+ q._unfinished_tasks = 2
+ self.assertEqual(q._format(), 'maxsize=0 tasks=2')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
new file mode 100644
index 0000000..0225e13
--- /dev/null
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -0,0 +1,1485 @@
+"""Tests for selector_events.py"""
+
+import collections
+import errno
+import gc
+import pprint
+import socket
+import sys
+import unittest
+import unittest.mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from asyncio import futures
+from asyncio import selectors
+from asyncio import test_utils
+from asyncio.protocols import DatagramProtocol, Protocol
+from asyncio.selector_events import BaseSelectorEventLoop
+from asyncio.selector_events import _SelectorTransport
+from asyncio.selector_events import _SelectorSslTransport
+from asyncio.selector_events import _SelectorSocketTransport
+from asyncio.selector_events import _SelectorDatagramTransport
+
+
+class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
+
+ def _make_self_pipe(self):
+ self._ssock = unittest.mock.Mock()
+ self._csock = unittest.mock.Mock()
+ self._internal_fds += 1
+
+
+class BaseSelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = TestBaseSelectorEventLoop(unittest.mock.Mock())
+
+ def test_make_socket_transport(self):
+ m = unittest.mock.Mock()
+ self.loop.add_reader = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
+
+ def test_make_ssl_transport(self):
+ m = unittest.mock.Mock()
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
+
+ def test_close(self):
+ ssock = self.loop._ssock
+ ssock.fileno.return_value = 7
+ csock = self.loop._csock
+ csock.fileno.return_value = 1
+ remove_reader = self.loop.remove_reader = unittest.mock.Mock()
+
+ self.loop._selector.close()
+ self.loop._selector = selector = unittest.mock.Mock()
+ self.loop.close()
+ self.assertIsNone(self.loop._selector)
+ self.assertIsNone(self.loop._csock)
+ self.assertIsNone(self.loop._ssock)
+ selector.close.assert_called_with()
+ ssock.close.assert_called_with()
+ csock.close.assert_called_with()
+ remove_reader.assert_called_with(7)
+
+ self.loop.close()
+ self.loop.close()
+
+ def test_close_no_selector(self):
+ ssock = self.loop._ssock
+ csock = self.loop._csock
+ remove_reader = self.loop.remove_reader = unittest.mock.Mock()
+
+ self.loop._selector.close()
+ self.loop._selector = None
+ self.loop.close()
+ self.assertIsNone(self.loop._selector)
+ self.assertFalse(ssock.close.called)
+ self.assertFalse(csock.close.called)
+ self.assertFalse(remove_reader.called)
+
+ def test_socketpair(self):
+ self.assertRaises(NotImplementedError, self.loop._socketpair)
+
+ def test_read_from_self_tryagain(self):
+ self.loop._ssock.recv.side_effect = BlockingIOError
+ self.assertIsNone(self.loop._read_from_self())
+
+ def test_read_from_self_exception(self):
+ self.loop._ssock.recv.side_effect = OSError
+ self.assertRaises(OSError, self.loop._read_from_self)
+
+ def test_write_to_self_tryagain(self):
+ self.loop._csock.send.side_effect = BlockingIOError
+ self.assertIsNone(self.loop._write_to_self())
+
+ def test_write_to_self_exception(self):
+ self.loop._csock.send.side_effect = OSError()
+ self.assertRaises(OSError, self.loop._write_to_self)
+
+ def test_sock_recv(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_recv = unittest.mock.Mock()
+
+ f = self.loop.sock_recv(sock, 1024)
+ self.assertIsInstance(f, futures.Future)
+ self.loop._sock_recv.assert_called_with(f, False, sock, 1024)
+
+ def test__sock_recv_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertFalse(sock.recv.called)
+
+ def test__sock_recv_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._sock_recv(f, True, sock, 1024)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_recv_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.recv.side_effect = BlockingIOError
+
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_recv_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.recv.side_effect = OSError()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertIs(err, f.exception())
+
+ def test_sock_sendall(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'data')
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, b'data'),
+ self.loop._sock_sendall.call_args[0])
+
+ def test_sock_sendall_nodata(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'')
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertFalse(self.loop._sock_sendall.called)
+
+ def test__sock_sendall_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(sock.send.called)
+
+ def test__sock_sendall_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, True, sock, b'data')
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_sendall_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = BlockingIOError
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_interrupted(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = InterruptedError
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.send.side_effect = OSError()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertIs(f.exception(), err)
+
+ def test__sock_sendall(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 4
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test__sock_sendall_partial(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 2
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'ta'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_none(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 0
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test_sock_connect(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_connect = unittest.mock.Mock()
+
+ f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, ('127.0.0.1', 8080)),
+ self.loop._sock_connect.call_args[0])
+
+ def test__sock_connect(self):
+ f = futures.Future(loop=self.loop)
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertTrue(sock.connect.called)
+
+ def test__sock_connect_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertFalse(sock.connect.called)
+
+ def test__sock_connect_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_connect_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.EAGAIN
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop.remove_writer = unittest.mock.Mock()
+
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual(
+ (10, self.loop._sock_connect, f,
+ True, sock, ('127.0.0.1', 8080)),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_connect_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.ENOTCONN
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f.exception(), OSError)
+
+ def test_sock_accept(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_accept = unittest.mock.Mock()
+
+ f = self.loop.sock_accept(sock)
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock), self.loop._sock_accept.call_args[0])
+
+ def test__sock_accept(self):
+ f = futures.Future(loop=self.loop)
+
+ conn = unittest.mock.Mock()
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.return_value = conn, ('127.0.0.1', 1000)
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertTrue(f.done())
+ self.assertEqual((conn, ('127.0.0.1', 1000)), f.result())
+ self.assertEqual((False,), conn.setblocking.call_args[0])
+
+ def test__sock_accept_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertFalse(sock.accept.called)
+
+ def test__sock_accept_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._sock_accept(f, True, sock)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_accept_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = BlockingIOError
+
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop._sock_accept(f, False, sock)
+ self.assertEqual(
+ (10, self.loop._sock_accept, f, True, sock),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_accept_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.accept.side_effect = OSError()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertIs(err, f.exception())
+
+ def test_add_reader(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertIsNone(w)
+
+ def test_add_reader_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (reader, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(reader.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_add_reader_existing_writer(self):
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_remove_reader(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (None, None))
+ self.assertFalse(self.loop.remove_reader(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_reader_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_reader(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE, (None, writer)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_reader_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_reader(1))
+
+ def test_add_writer(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE, mask)
+ self.assertIsNone(r)
+ self.assertEqual(cb, w._callback)
+
+ def test_add_writer_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, writer))
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(writer.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(reader, r)
+ self.assertEqual(cb, w._callback)
+
+ def test_remove_writer(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, None))
+ self.assertFalse(self.loop.remove_writer(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_writer_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_writer(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_READ, (reader, None)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_writer_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_writer(1))
+
+ def test_process_events_read(self):
+ reader = unittest.mock.Mock()
+ reader._cancelled = False
+
+ self.loop._add_callback = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.assertTrue(self.loop._add_callback.called)
+ self.loop._add_callback.assert_called_with(reader)
+
+ def test_process_events_read_cancelled(self):
+ reader = unittest.mock.Mock()
+ reader.cancelled = True
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.loop.remove_reader.assert_called_with(1)
+
+ def test_process_events_write(self):
+ writer = unittest.mock.Mock()
+ writer._cancelled = False
+
+ self.loop._add_callback = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop._add_callback.assert_called_with(writer)
+
+ def test_process_events_write_cancelled(self):
+ writer = unittest.mock.Mock()
+ writer.cancelled = True
+ self.loop.remove_writer = unittest.mock.Mock()
+
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop.remove_writer.assert_called_with(1)
+
+
+class SelectorTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def test_ctor(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ self.assertIs(tr._loop, self.loop)
+ self.assertIs(tr._sock, self.sock)
+ self.assertIs(tr._sock_fd, 7)
+
+ def test_abort(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._force_close = unittest.mock.Mock()
+
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr.close()
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+ self.protocol.connection_lost(None)
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ def test_close_write_buffer(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._buffer.append(b'data')
+ tr.close()
+
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_force_close(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._buffer.append(b'1')
+ self.loop.add_reader(7, unittest.mock.sentinel)
+ self.loop.add_writer(7, unittest.mock.sentinel)
+ tr._force_close(None)
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(tr._buffer, collections.deque())
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+
+ # second close should not remove reader
+ tr._force_close(None)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ def test_fatal_error(self, m_exc):
+ exc = OSError()
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._force_close = unittest.mock.Mock()
+ tr._fatal_error(exc)
+
+ m_exc.assert_called_with('Fatal error for %s', tr)
+ tr._force_close.assert_called_with(exc)
+
+ def test_connection_lost(self):
+ exc = OSError()
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._call_connection_lost(exc)
+
+ self.protocol.connection_lost.assert_called_with(exc)
+ self.sock.close.assert_called_with()
+ self.assertIsNone(tr._sock)
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+
+class SelectorSocketTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock_fd = self.sock.fileno.return_value = 7
+
+ def test_ctor(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.loop.assert_reader(7, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+
+ _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+
+ def test_pause_resume(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+ tr.pause()
+ self.assertTrue(tr._paused)
+ self.assertFalse(7 in self.loop.readers)
+ tr.resume()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+
+ def test_read_ready(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recv.return_value = b'data'
+ transport._read_ready()
+
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_read_ready_eof(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close = unittest.mock.Mock()
+
+ self.sock.recv.return_value = b''
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ transport.close.assert_called_with()
+
+ def test_read_ready_eof_keep_open(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close = unittest.mock.Mock()
+
+ self.sock.recv.return_value = b''
+ self.protocol.eof_received.return_value = True
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ self.assertFalse(transport.close.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_tryagain(self, m_exc):
+ self.sock.recv.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_tryagain_interrupted(self, m_exc):
+ self.sock.recv.side_effect = InterruptedError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_conn_reset(self, m_exc):
+ err = self.sock.recv.side_effect = ConnectionResetError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._force_close = unittest.mock.Mock()
+ transport._read_ready()
+ transport._force_close.assert_called_with(err)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_err(self, m_exc):
+ err = self.sock.recv.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_write(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+ self.sock.send.assert_called_with(data)
+
+ def test_write_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(b'data')
+ transport.write(b'')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_buffer(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(b'data1')
+ transport.write(b'data2')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(collections.deque([b'data1', b'data2']),
+ transport._buffer)
+
+ def test_write_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'ta']), transport._buffer)
+
+ def test_write_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+ self.sock.fileno.return_value = 7
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_write_exception(self, m_log):
+ err = self.sock.send.side_effect = OSError()
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.write(data)
+ transport._fatal_error.assert_called_with(err)
+ transport._conn_lost = 1
+
+ self.sock.reset_mock()
+ transport.write(data)
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(transport._conn_lost, 2)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_write_str(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_write_ready(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.assertTrue(self.sock.send.called)
+ self.assertEqual(self.sock.send.call_args[0], (data,))
+ self.assertFalse(self.loop.writers)
+
+ def test_write_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.sock.send.assert_called_with(data)
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport._write_ready)
+
+ def test_write_ready_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'ta']), transport._buffer)
+
+ def test_write_ready_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_ready_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+
+ def test_write_ready_exception(self):
+ err = self.sock.send.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append(b'data')
+ transport._write_ready()
+ transport._fatal_error.assert_called_with(err)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_write_ready_exception_and_close(self, m_log):
+ self.sock.send.side_effect = OSError()
+ remove_writer = self.loop.remove_writer = unittest.mock.Mock()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close()
+ transport._buffer.append(b'data')
+ transport._write_ready()
+ remove_writer.assert_called_with(self.sock_fd)
+
+ def test_write_eof(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.sock.send.side_effect = BlockingIOError
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertEqual(tr._buffer, collections.deque([b'data']))
+ self.assertTrue(tr._eof)
+ self.assertFalse(self.sock.shutdown.called)
+ self.sock.send.side_effect = lambda _: 4
+ tr._write_ready()
+ self.sock.send.assert_called_with(b'data')
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class SelectorSslTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+ self.sslsock = unittest.mock.Mock()
+ self.sslsock.fileno.return_value = 1
+ self.sslcontext = unittest.mock.Mock()
+ self.sslcontext.wrap_socket.return_value = self.sslsock
+
+ def _make_one(self, create_waiter=None):
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ self.sock.reset_mock()
+ self.sslsock.reset_mock()
+ self.sslcontext.reset_mock()
+ self.loop.reset_counters()
+ return transport
+
+ def test_on_handshake(self):
+ waiter = futures.Future(loop=self.loop)
+ tr = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext,
+ waiter=waiter)
+ self.assertTrue(self.sslsock.do_handshake.called)
+ self.loop.assert_reader(1, tr._on_ready)
+ self.loop.assert_writer(1, tr._on_ready)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(waiter.result())
+
+ def test_on_handshake_reader_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._on_handshake()
+ self.loop.assert_reader(1, transport._on_handshake)
+
+ def test_on_handshake_writer_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._on_handshake()
+ self.loop.assert_writer(1, transport._on_handshake)
+
+ def test_on_handshake_exc(self):
+ exc = ValueError()
+ self.sslsock.do_handshake.side_effect = exc
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._waiter = futures.Future(loop=self.loop)
+ transport._on_handshake()
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(transport._waiter.done())
+ self.assertIs(exc, transport._waiter.exception())
+
+ def test_on_handshake_base_exc(self):
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._waiter = futures.Future(loop=self.loop)
+ exc = BaseException()
+ self.sslsock.do_handshake.side_effect = exc
+ self.assertRaises(BaseException, transport._on_handshake)
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(transport._waiter.done())
+ self.assertIs(exc, transport._waiter.exception())
+
+ def test_pause_resume(self):
+ tr = self._make_one()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._on_ready)
+ tr.pause()
+ self.assertTrue(tr._paused)
+ self.assertFalse(1 in self.loop.readers)
+ tr.resume()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._on_ready)
+
+ def test_write_no_data(self):
+ transport = self._make_one()
+ transport._buffer.append(b'data')
+ transport.write(b'')
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_str(self):
+ transport = self._make_one()
+ self.assertRaises(AssertionError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = self._make_one()
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_write_exception(self, m_log):
+ transport = self._make_one()
+ transport._conn_lost = 1
+ transport.write(b'data')
+ self.assertEqual(transport._buffer, collections.deque())
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_on_ready_recv(self):
+ self.sslsock.recv.return_value = b'data'
+ transport = self._make_one()
+ transport._on_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertEqual((b'data',), self.protocol.data_received.call_args[0])
+
+ def test_on_ready_recv_eof(self):
+ self.sslsock.recv.return_value = b''
+ transport = self._make_one()
+ transport.close = unittest.mock.Mock()
+ transport._on_ready()
+ transport.close.assert_called_with()
+ self.protocol.eof_received.assert_called_with()
+
+ def test_on_ready_recv_conn_reset(self):
+ err = self.sslsock.recv.side_effect = ConnectionResetError()
+ transport = self._make_one()
+ transport._force_close = unittest.mock.Mock()
+ transport._on_ready()
+ transport._force_close.assert_called_with(err)
+
+ def test_on_ready_recv_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ transport = self._make_one()
+ transport._on_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = ssl.SSLWantWriteError
+ transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = BlockingIOError
+ transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = InterruptedError
+ transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ def test_on_ready_recv_exc(self):
+ err = self.sslsock.recv.side_effect = OSError()
+ transport = self._make_one()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._on_ready()
+ transport._fatal_error.assert_called_with(err)
+
+ def test_on_ready_send(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+ transport._on_ready()
+ self.assertEqual(collections.deque(), transport._buffer)
+ self.assertTrue(self.sslsock.send.called)
+
+ def test_on_ready_send_none(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 0
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+
+ def test_on_ready_send_partial(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'ta1data2']), transport._buffer)
+
+ def test_on_ready_send_closing_partial(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertFalse(self.sslsock.close.called)
+
+ def test_on_ready_send_closing(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = collections.deque([b'data'])
+ transport._on_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_on_ready_send_closing_empty_buffer(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = collections.deque()
+ transport._on_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_on_ready_send_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+
+ self.sslsock.send.side_effect = ssl.SSLWantReadError
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ self.sslsock.send.side_effect = ssl.SSLWantWriteError
+ transport._on_ready()
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ self.sslsock.send.side_effect = BlockingIOError()
+ transport._on_ready()
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_on_ready_send_exc(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ err = self.sslsock.send.side_effect = OSError()
+
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+ transport._fatal_error = unittest.mock.Mock()
+ transport._on_ready()
+ transport._fatal_error.assert_called_with(err)
+ self.assertEqual(collections.deque(), transport._buffer)
+
+ def test_write_eof(self):
+ tr = self._make_one()
+ self.assertFalse(tr.can_write_eof())
+ self.assertRaises(NotImplementedError, tr.write_eof)
+
+ def test_close(self):
+ tr = self._make_one()
+ tr.close()
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+
+ @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support')
+ def test_server_hostname(self):
+ _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext,
+ server_hostname='localhost')
+ self.sslcontext.wrap_socket.assert_called_with(
+ self.sock, do_handshake_on_connect=False, server_side=False,
+ server_hostname='localhost')
+
+
+class SelectorDatagramTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(DatagramProtocol)
+ self.sock = unittest.mock.Mock(spec_set=socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def test_read_ready(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234))
+ transport._read_ready()
+
+ self.protocol.datagram_received.assert_called_with(
+ b'data', ('0.0.0.0', 1234))
+
+ def test_read_ready_tryagain(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.side_effect = BlockingIOError
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_read_ready_err(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ err = self.sock.recvfrom.side_effect = OSError()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_sendto(self):
+ data = b'data'
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+ def test_sendto_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+ transport.sendto(b'', ())
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ def test_sendto_buffer(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport.sendto(b'data2', ('0.0.0.0', 12345))
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+
+ def test_sendto_tryagain(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 12345))
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_sendto_exception(self, m_log):
+ data = b'data'
+ err = self.sock.sendto.side_effect = OSError()
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertTrue(transport._fatal_error.called)
+ transport._fatal_error.assert_called_with(err)
+ transport._conn_lost = 1
+
+ transport._address = ('123',)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_sendto_connection_refused(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertEqual(transport._conn_lost, 0)
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_connection_refused_connected(self):
+ data = b'data'
+
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data)
+
+ self.assertTrue(transport._fatal_error.called)
+
+ def test_sendto_str(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.sendto, 'str', ())
+
+ def test_sendto_connected_addr(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ self.assertRaises(
+ AssertionError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+ def test_sendto_closing(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, address=(1,))
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.sendto(b'data', (1,))
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_sendto_ready(self):
+ data = b'data'
+ self.sock.sendto.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((data, ('0.0.0.0', 12345)))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345)))
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append((data, ()))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.sock.sendto.assert_called_with(data, ())
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_sendto_ready_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertFalse(self.sock.sendto.called)
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_tryagain(self):
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.extend([(b'data1', ()), (b'data2', ())])
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data1', ()), (b'data2', ())],
+ list(transport._buffer))
+
+ def test_sendto_ready_exception(self):
+ err = self.sock.sendto.side_effect = OSError()
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_sendto_ready_connection_refused(self):
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_ready_connection_refused_connection(self):
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertTrue(transport._fatal_error.called)
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ def test_fatal_error_connected(self, m_exc):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ err = ConnectionRefusedError()
+ transport._fatal_error(err)
+ self.protocol.connection_refused.assert_called_with(err)
+ m_exc.assert_called_with('Fatal error for %s', transport)
diff --git a/Lib/test/test_asyncio/test_selectors.py b/Lib/test/test_asyncio/test_selectors.py
new file mode 100644
index 0000000..2f7dc69
--- /dev/null
+++ b/Lib/test/test_asyncio/test_selectors.py
@@ -0,0 +1,145 @@
+"""Tests for selectors.py."""
+
+import unittest
+import unittest.mock
+
+from asyncio import selectors
+
+
+class FakeSelector(selectors.BaseSelector):
+ """Trivial non-abstract subclass of BaseSelector."""
+
+ def select(self, timeout=None):
+ raise NotImplementedError
+
+
+class BaseSelectorTests(unittest.TestCase):
+
+ def test_fileobj_to_fd(self):
+ self.assertEqual(10, selectors._fileobj_to_fd(10))
+
+ f = unittest.mock.Mock()
+ f.fileno.return_value = 10
+ self.assertEqual(10, selectors._fileobj_to_fd(f))
+
+ f.fileno.side_effect = AttributeError
+ self.assertRaises(ValueError, selectors._fileobj_to_fd, f)
+
+ def test_selector_key_repr(self):
+ key = selectors.SelectorKey(10, 10, selectors.EVENT_READ, None)
+ self.assertEqual(
+ "SelectorKey(fileobj=10, fd=10, events=1, data=None)", repr(key))
+
+ def test_register(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ)
+ self.assertIsInstance(key, selectors.SelectorKey)
+ self.assertEqual(key.fd, 10)
+ self.assertIs(key, s._fd_to_key[10])
+
+ def test_register_unknown_event(self):
+ s = FakeSelector()
+ self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999)
+
+ def test_register_already_registered(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ s.register(fobj, selectors.EVENT_READ)
+ self.assertRaises(KeyError, s.register, fobj, selectors.EVENT_READ)
+
+ def test_unregister(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ s.register(fobj, selectors.EVENT_READ)
+ s.unregister(fobj)
+ self.assertFalse(s._fd_to_key)
+
+ def test_unregister_unknown(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ self.assertRaises(KeyError, s.unregister, fobj)
+
+ def test_modify_unknown(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ self.assertRaises(KeyError, s.modify, fobj, 1)
+
+ def test_modify(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ)
+ key2 = s.modify(fobj, selectors.EVENT_WRITE)
+ self.assertNotEqual(key.events, key2.events)
+ self.assertEqual(
+ selectors.SelectorKey(fobj, 10, selectors.EVENT_WRITE, None),
+ s.get_key(fobj))
+
+ def test_modify_data(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ d1 = object()
+ d2 = object()
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ, d1)
+ key2 = s.modify(fobj, selectors.EVENT_READ, d2)
+ self.assertEqual(key.events, key2.events)
+ self.assertNotEqual(key.data, key2.data)
+ self.assertEqual(
+ selectors.SelectorKey(fobj, 10, selectors.EVENT_READ, d2),
+ s.get_key(fobj))
+
+ def test_modify_same(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ data = object()
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ, data)
+ key2 = s.modify(fobj, selectors.EVENT_READ, data)
+ self.assertIs(key, key2)
+
+ def test_select(self):
+ s = FakeSelector()
+ self.assertRaises(NotImplementedError, s.select)
+
+ def test_close(self):
+ s = FakeSelector()
+ s.register(1, selectors.EVENT_READ)
+
+ s.close()
+ self.assertFalse(s._fd_to_key)
+
+ def test_context_manager(self):
+ s = FakeSelector()
+
+ with s as sel:
+ sel.register(1, selectors.EVENT_READ)
+
+ self.assertFalse(s._fd_to_key)
+
+ def test_key_from_fd(self):
+ s = FakeSelector()
+ key = s.register(1, selectors.EVENT_READ)
+
+ self.assertIs(key, s._key_from_fd(1))
+ self.assertIsNone(s._key_from_fd(10))
+
+ if hasattr(selectors.DefaultSelector, 'fileno'):
+ def test_fileno(self):
+ self.assertIsInstance(selectors.DefaultSelector().fileno(), int)
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
new file mode 100644
index 0000000..011a09d
--- /dev/null
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -0,0 +1,361 @@
+"""Tests for streams.py."""
+
+import gc
+import ssl
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import streams
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class StreamReaderTests(unittest.TestCase):
+
+ DATA = b'line1\nline2\nline3\n'
+
+ def setUp(self):
+ self.loop = events.new_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+
+ @unittest.mock.patch('asyncio.streams.events')
+ def test_ctor_global_loop(self, m_events):
+ stream = streams.StreamReader()
+ self.assertIs(stream.loop, m_events.get_event_loop.return_value)
+
+ def test_open_connection(self):
+ with test_utils.run_test_server() as httpd:
+ f = streams.open_connection(*httpd.address, loop=self.loop)
+ reader, writer = self.loop.run_until_complete(f)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_open_connection_no_loop_ssl(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ try:
+ events.set_event_loop(self.loop)
+ f = streams.open_connection(*httpd.address,
+ ssl=test_utils.dummy_ssl_context())
+ reader, writer = self.loop.run_until_complete(f)
+ finally:
+ events.set_event_loop(None)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
+
+ def test_open_connection_error(self):
+ with test_utils.run_test_server() as httpd:
+ f = streams.open_connection(*httpd.address, loop=self.loop)
+ reader, writer = self.loop.run_until_complete(f)
+ writer._protocol.connection_lost(ZeroDivisionError())
+ f = reader.read()
+ with self.assertRaises(ZeroDivisionError):
+ self.loop.run_until_complete(f)
+
+ writer.close()
+ test_utils.run_briefly(self.loop)
+
+ def test_feed_empty_data(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ stream.feed_data(b'')
+ self.assertEqual(0, stream.byte_count)
+
+ def test_feed_data_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ stream.feed_data(self.DATA)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_read_zero(self):
+ # Read zero bytes.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.read(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_read(self):
+ # Read bytes.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(30), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_line_breaks(self):
+ # Read bytes without line breaks.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line1')
+ stream.feed_data(b'line2')
+
+ data = self.loop.run_until_complete(stream.read(5))
+
+ self.assertEqual(b'line1', data)
+ self.assertEqual(5, stream.byte_count)
+
+ def test_read_eof(self):
+ # Read bytes, stop at eof.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(1024), loop=self.loop)
+
+ def cb():
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'', data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_until_eof(self):
+ # Read all bytes until eof.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(-1), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk1\n')
+ stream.feed_data(b'chunk2')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+
+ self.assertEqual(b'chunk1\nchunk2', data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.read(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.read(2))
+
+ def test_readline(self):
+ # Read one line.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'chunk1 ')
+ read_task = tasks.Task(stream.readline(), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk2 ')
+ stream.feed_data(b'chunk3 ')
+ stream.feed_data(b'\n chunk4')
+ self.loop.call_soon(cb)
+
+ line = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
+ self.assertEqual(len(b'\n chunk4')-1, stream.byte_count)
+
+ def test_readline_limit_with_existing_data(self):
+ stream = streams.StreamReader(3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1\nline2\n')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'line2\n'], list(stream.buffer))
+
+ stream = streams.StreamReader(3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1')
+ stream.feed_data(b'li')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'li'], list(stream.buffer))
+ self.assertEqual(2, stream.byte_count)
+
+ def test_readline_limit(self):
+ stream = streams.StreamReader(7, loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk1')
+ stream.feed_data(b'chunk2')
+ stream.feed_data(b'chunk3\n')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'chunk3\n'], list(stream.buffer))
+ self.assertEqual(7, stream.byte_count)
+
+ def test_readline_line_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA[:6])
+ stream.feed_data(self.DATA[6:])
+
+ line = self.loop.run_until_complete(stream.readline())
+
+ self.assertEqual(b'line1\n', line)
+ self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count)
+
+ def test_readline_eof(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'some data')
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'some data', line)
+
+ def test_readline_empty_eof(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'', line)
+
+ def test_readline_read_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ self.loop.run_until_complete(stream.readline())
+
+ data = self.loop.run_until_complete(stream.read(7))
+
+ self.assertEqual(b'line2\nl', data)
+ self.assertEqual(
+ len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
+ stream.byte_count)
+
+ def test_readline_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'line\n', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+
+ def test_readexactly_zero_or_less(self):
+ # Read exact number of bytes (zero or less).
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.readexactly(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ data = self.loop.run_until_complete(stream.readexactly(-1))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_readexactly(self):
+ # Read exact number of bytes.
+ stream = streams.StreamReader(loop=self.loop)
+
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA + self.DATA, data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_readexactly_eof(self):
+ # Read exact number of bytes (eof).
+ stream = streams.StreamReader(loop=self.loop)
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream.byte_count)
+
+ def test_readexactly_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readexactly(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readexactly(2))
+
+ def test_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ self.assertIsNone(stream.exception())
+
+ exc = ValueError()
+ stream.set_exception(exc)
+ self.assertIs(stream.exception(), exc)
+
+ def test_exception_waiter(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ @tasks.coroutine
+ def set_err():
+ stream.set_exception(ValueError())
+
+ @tasks.coroutine
+ def readline():
+ yield from stream.readline()
+
+ t1 = tasks.Task(stream.readline(), loop=self.loop)
+ t2 = tasks.Task(set_err(), loop=self.loop)
+
+ self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop))
+
+ self.assertRaises(ValueError, t1.result)
+
+ def test_exception_cancel(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ @tasks.coroutine
+ def read_a_line():
+ yield from stream.readline()
+
+ t = tasks.Task(read_a_line(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ t.cancel()
+ test_utils.run_briefly(self.loop)
+ # The following line fails if set_exception() isn't careful.
+ stream.set_exception(RuntimeError('message'))
+ test_utils.run_briefly(self.loop)
+ self.assertIs(stream.waiter, None)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
new file mode 100644
index 0000000..57fb053
--- /dev/null
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -0,0 +1,1518 @@
+"""Tests for tasks.py."""
+
+import gc
+import unittest
+import unittest.mock
+from unittest.mock import Mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class Dummy:
+
+ def __repr__(self):
+ return 'Dummy()'
+
+ def __call__(self, *args):
+ pass
+
+
+class TaskTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ gc.collect()
+
+ def test_task_class(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t = tasks.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = events.new_event_loop()
+ t = tasks.Task(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.close()
+
+ def test_async_coroutine(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t = tasks.async(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = events.new_event_loop()
+ t = tasks.async(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.close()
+
+ def test_async_future(self):
+ f_orig = futures.Future(loop=self.loop)
+ f_orig.set_result('ko')
+
+ f = tasks.async(f_orig)
+ self.loop.run_until_complete(f)
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 'ko')
+ self.assertIs(f, f_orig)
+
+ loop = events.new_event_loop()
+
+ with self.assertRaises(ValueError):
+ f = tasks.async(f_orig, loop=loop)
+
+ loop.close()
+
+ f = tasks.async(f_orig, loop=self.loop)
+ self.assertIs(f, f_orig)
+
+ def test_async_task(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t_orig = tasks.Task(notmuch(), loop=self.loop)
+ t = tasks.async(t_orig)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t, t_orig)
+
+ loop = events.new_event_loop()
+
+ with self.assertRaises(ValueError):
+ t = tasks.async(t_orig, loop=loop)
+
+ loop.close()
+
+ t = tasks.async(t_orig, loop=self.loop)
+ self.assertIs(t, t_orig)
+
+ def test_async_neither(self):
+ with self.assertRaises(TypeError):
+ tasks.async('ok')
+
+ def test_task_repr(self):
+ @tasks.coroutine
+ def notmuch():
+ yield from []
+ return 'abc'
+
+ t = tasks.Task(notmuch(), loop=self.loop)
+ t.add_done_callback(Dummy())
+ self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>')
+ t.cancel() # Does not take immediate effect!
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>')
+ self.assertRaises(futures.CancelledError,
+ self.loop.run_until_complete, t)
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>')
+ t = tasks.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>")
+
+ def test_task_repr_custom(self):
+ @tasks.coroutine
+ def coro():
+ pass
+
+ class T(futures.Future):
+ def __repr__(self):
+ return 'T[]'
+
+ class MyTask(tasks.Task, T):
+ def __repr__(self):
+ return super().__repr__()
+
+ gen = coro()
+ t = MyTask(gen, loop=self.loop)
+ self.assertEqual(repr(t), 'T[](<coro>)')
+ gen.close()
+
+ def test_task_basics(self):
+ @tasks.coroutine
+ def outer():
+ a = yield from inner1()
+ b = yield from inner2()
+ return a+b
+
+ @tasks.coroutine
+ def inner1():
+ return 42
+
+ @tasks.coroutine
+ def inner2():
+ return 1000
+
+ t = outer()
+ self.assertEqual(self.loop.run_until_complete(t), 1042)
+
+ def test_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def task():
+ yield from tasks.sleep(10.0, loop=loop)
+ return 12
+
+ t = tasks.Task(task(), loop=loop)
+ loop.call_soon(t.cancel)
+ with self.assertRaises(futures.CancelledError):
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_yield(self):
+ @tasks.coroutine
+ def task():
+ yield
+ yield
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start coro
+ t.cancel()
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_inner_future(self):
+ f = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start task
+ f.cancel()
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(t)
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_both_task_and_inner_future(self):
+ f = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+
+ f.cancel()
+ t.cancel()
+
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(t)
+
+ self.assertTrue(t.done())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_task_catching(self):
+ fut1 = futures.Future(loop=self.loop)
+ fut2 = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except futures.CancelledError:
+ return 42
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_task_ignoring(self):
+ fut1 = futures.Future(loop=self.loop)
+ fut2 = futures.Future(loop=self.loop)
+ fut3 = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except futures.CancelledError:
+ pass
+ res = yield from fut3
+ return res
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut3) # White-box test.
+ fut3.set_result(42)
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(fut3.cancelled())
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_current_task(self):
+ loop = events.new_event_loop()
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def task():
+ t.cancel()
+ self.assertTrue(t._must_cancel) # White-box test.
+ # The sleep should be cancelled immediately.
+ yield from tasks.sleep(100, loop=loop)
+ return 12
+
+ t = tasks.Task(task(), loop=loop)
+ self.assertRaises(
+ futures.CancelledError, loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertFalse(t._must_cancel) # White-box test.
+ self.assertFalse(t.cancel())
+
+ def test_stop_while_run_in_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.3, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ x = 0
+ waiters = []
+
+ @tasks.coroutine
+ def task():
+ nonlocal x
+ while x < 10:
+ waiters.append(tasks.sleep(0.1, loop=loop))
+ yield from waiters[-1]
+ x += 1
+ if x == 2:
+ loop.stop()
+
+ t = tasks.Task(task(), loop=loop)
+ self.assertRaises(
+ RuntimeError, loop.run_until_complete, t)
+ self.assertFalse(t.done())
+ self.assertEqual(x, 2)
+ self.assertAlmostEqual(0.3, loop.time())
+
+ # close generators
+ for w in waiters:
+ w.close()
+
+ def test_wait_for(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.4, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def foo():
+ yield from tasks.sleep(0.2, loop=loop)
+ return 'done'
+
+ fut = tasks.Task(foo(), loop=loop)
+
+ with self.assertRaises(futures.TimeoutError):
+ loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop))
+
+ self.assertFalse(fut.done())
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # wait for result
+ res = loop.run_until_complete(
+ tasks.wait_for(fut, 0.3, loop=loop))
+ self.assertEqual(res, 'done')
+ self.assertAlmostEqual(0.2, loop.time())
+
+ def test_wait_for_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def foo():
+ yield from tasks.sleep(0.2, loop=loop)
+ return 'done'
+
+ events.set_event_loop(loop)
+ try:
+ fut = tasks.Task(foo(), loop=loop)
+ with self.assertRaises(futures.TimeoutError):
+ loop.run_until_complete(tasks.wait_for(fut, 0.01))
+ finally:
+ events.set_event_loop(None)
+
+ self.assertAlmostEqual(0.01, loop.time())
+ self.assertFalse(fut.done())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(fut)
+
+ def test_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], loop=loop)
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertEqual(res, 42)
+ self.assertAlmostEqual(0.15, loop.time())
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertEqual(res, 42)
+
+ def test_wait_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0
+ self.assertAlmostEqual(0.015, when)
+ yield 0.015
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a])
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ events.set_event_loop(loop)
+ try:
+ res = loop.run_until_complete(
+ tasks.Task(foo(), loop=loop))
+ finally:
+ events.set_event_loop(None)
+
+ self.assertEqual(res, 42)
+
+ def test_wait_errors(self):
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete,
+ tasks.wait(set(), loop=self.loop))
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete,
+ tasks.wait([tasks.sleep(10.0, loop=self.loop)],
+ return_when=-1, loop=self.loop))
+
+ def test_wait_first_completed(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertFalse(a.done())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_really_done(self):
+ # there is possibility that some tasks in the pending list
+ # became done but their callbacks haven't all been called yet
+
+ @tasks.coroutine
+ def coro1():
+ yield
+
+ @tasks.coroutine
+ def coro2():
+ yield
+ yield
+
+ a = tasks.Task(coro1(), loop=self.loop)
+ b = tasks.Task(coro2(), loop=self.loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED,
+ loop=self.loop),
+ loop=self.loop)
+
+ done, pending = self.loop.run_until_complete(task)
+ self.assertEqual({a, b}, done)
+ self.assertTrue(a.done())
+ self.assertIsNone(a.result())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+
+ def test_wait_first_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ # first_exception, task already has exception
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def exc():
+ raise ZeroDivisionError('err')
+
+ b = tasks.Task(exc(), loop=loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_first_exception_in_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ # first_exception, exception during waiting
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def exc():
+ yield from tasks.sleep(0.01, loop=loop)
+ raise ZeroDivisionError('err')
+
+ b = tasks.Task(exc(), loop=loop)
+ task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION,
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_with_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(0.15, loop=loop)
+ raise ZeroDivisionError('really')
+
+ b = tasks.Task(sleeper(), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], loop=loop)
+ self.assertEqual(len(done), 2)
+ self.assertEqual(pending, set())
+ errors = set(f for f in done if f.exception() is not None)
+ self.assertEqual(len(errors), 1)
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_wait_with_timeout(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.11, when)
+ yield 0.11
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], timeout=0.11,
+ loop=loop)
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.11, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_concurrent_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ done, pending = loop.run_until_complete(
+ tasks.wait([b, a], timeout=0.1, loop=loop))
+
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_as_completed(self):
+
+ def gen():
+ yield 0
+ yield 0
+ yield 0.01
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+ completed = set()
+ time_shifted = False
+
+ @tasks.coroutine
+ def sleeper(dt, x):
+ nonlocal time_shifted
+ yield from tasks.sleep(dt, loop=loop)
+ completed.add(x)
+ if not time_shifted and 'a' in completed and 'b' in completed:
+ time_shifted = True
+ loop.advance_time(0.14)
+ return x
+
+ a = sleeper(0.01, 'a')
+ b = sleeper(0.01, 'b')
+ c = sleeper(0.15, 'c')
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([b, c, a], loop=loop):
+ values.append((yield from f))
+ return values
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertTrue('a' in res[:2])
+ self.assertTrue('b' in res[:2])
+ self.assertEqual(res[2], 'c')
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_as_completed_with_timeout(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.12, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.12, when)
+ yield 0.02
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.1, 'a', loop=loop)
+ b = tasks.sleep(0.15, 'b', loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([a, b], timeout=0.12, loop=loop):
+ try:
+ v = yield from f
+ values.append((1, v))
+ except futures.TimeoutError as exc:
+ values.append((2, exc))
+ return values
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertEqual(len(res), 2, res)
+ self.assertEqual(res[0], (1, 'a'))
+ self.assertEqual(res[1][0], 2)
+ self.assertTrue(isinstance(res[1][1], futures.TimeoutError))
+ self.assertAlmostEqual(0.12, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_as_completed_reverse_wait(self):
+
+ def gen():
+ yield 0
+ yield 0.05
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.05, 'a', loop=loop)
+ b = tasks.sleep(0.10, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(tasks.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+
+ x = loop.run_until_complete(futs[1])
+ self.assertEqual(x, 'a')
+ self.assertAlmostEqual(0.05, loop.time())
+ loop.advance_time(0.05)
+ y = loop.run_until_complete(futs[0])
+ self.assertEqual(y, 'b')
+ self.assertAlmostEqual(0.10, loop.time())
+
+ def test_as_completed_concurrent(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0
+ self.assertAlmostEqual(0.05, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.05, 'a', loop=loop)
+ b = tasks.sleep(0.05, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(tasks.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+ waiter = tasks.wait(futs, loop=loop)
+ done, pending = loop.run_until_complete(waiter)
+ self.assertEqual(set(f.result() for f in done), {'a', 'b'})
+
+ def test_sleep(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0.05
+ self.assertAlmostEqual(0.1, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def sleeper(dt, arg):
+ yield from tasks.sleep(dt/2, loop=loop)
+ res = yield from tasks.sleep(dt/2, arg, loop=loop)
+ return res
+
+ t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop)
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'yeah')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_sleep_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop),
+ loop=loop)
+
+ handle = None
+ orig_call_later = loop.call_later
+
+ def call_later(self, delay, callback, *args):
+ nonlocal handle
+ handle = orig_call_later(self, delay, callback, *args)
+ return handle
+
+ loop.call_later = call_later
+ test_utils.run_briefly(loop)
+
+ self.assertFalse(handle._cancelled)
+
+ t.cancel()
+ test_utils.run_briefly(loop)
+ self.assertTrue(handle._cancelled)
+
+ def test_task_cancel_sleeping_task(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(5000, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ sleepfut = None
+
+ @tasks.coroutine
+ def sleep(dt):
+ nonlocal sleepfut
+ sleepfut = tasks.sleep(dt, loop=loop)
+ yield from sleepfut
+
+ @tasks.coroutine
+ def doit():
+ sleeper = tasks.Task(sleep(5000), loop=loop)
+ loop.call_later(0.1, sleeper.cancel)
+ try:
+ yield from sleeper
+ except futures.CancelledError:
+ return 'cancelled'
+ else:
+ return 'slept in'
+
+ doer = doit()
+ self.assertEqual(loop.run_until_complete(doer), 'cancelled')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_task_cancel_waiter_future(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def coro():
+ yield from fut
+
+ task = tasks.Task(coro(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(task._fut_waiter, fut)
+
+ task.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, task)
+ self.assertIsNone(task._fut_waiter)
+ self.assertTrue(fut.cancelled())
+
+ def test_step_in_completed_task(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ko'
+
+ gen = notmuch()
+ task = tasks.Task(gen, loop=self.loop)
+ task.set_result('ok')
+
+ self.assertRaises(AssertionError, task._step)
+ gen.close()
+
+ def test_step_result(self):
+ @tasks.coroutine
+ def notmuch():
+ yield None
+ yield 1
+ return 'ko'
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, notmuch())
+
+ def test_step_result_future(self):
+ # If coroutine returns future, task waits on this future.
+
+ class Fut(futures.Future):
+ def __init__(self, *args, **kwds):
+ self.cb_added = False
+ super().__init__(*args, **kwds)
+
+ def add_done_callback(self, fn):
+ self.cb_added = True
+ super().add_done_callback(fn)
+
+ fut = Fut(loop=self.loop)
+ result = None
+
+ @tasks.coroutine
+ def wait_for_future():
+ nonlocal result
+ result = yield from fut
+
+ t = tasks.Task(wait_for_future(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(fut.cb_added)
+
+ res = object()
+ fut.set_result(res)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(res, result)
+ self.assertTrue(t.done())
+ self.assertIsNone(t.result())
+
+ def test_step_with_baseexception(self):
+ @tasks.coroutine
+ def notmutch():
+ raise BaseException()
+
+ task = tasks.Task(notmutch(), loop=self.loop)
+ self.assertRaises(BaseException, task._step)
+
+ self.assertTrue(task.done())
+ self.assertIsInstance(task.exception(), BaseException)
+
+ def test_baseexception_during_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(10, loop=loop)
+
+ base_exc = BaseException()
+
+ @tasks.coroutine
+ def notmutch():
+ try:
+ yield from sleeper()
+ except futures.CancelledError:
+ raise base_exc
+
+ task = tasks.Task(notmutch(), loop=loop)
+ test_utils.run_briefly(loop)
+
+ task.cancel()
+ self.assertFalse(task.done())
+
+ self.assertRaises(BaseException, test_utils.run_briefly, loop)
+
+ self.assertTrue(task.done())
+ self.assertFalse(task.cancelled())
+ self.assertIs(task.exception(), base_exc)
+
+ def test_iscoroutinefunction(self):
+ def fn():
+ pass
+
+ self.assertFalse(tasks.iscoroutinefunction(fn))
+
+ def fn1():
+ yield
+ self.assertFalse(tasks.iscoroutinefunction(fn1))
+
+ @tasks.coroutine
+ def fn2():
+ yield
+ self.assertTrue(tasks.iscoroutinefunction(fn2))
+
+ def test_yield_vs_yield_from(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def wait_for_future():
+ yield fut
+
+ task = wait_for_future()
+ with self.assertRaises(RuntimeError):
+ self.loop.run_until_complete(task)
+
+ self.assertFalse(fut.done())
+
+ def test_yield_vs_yield_from_generator(self):
+ @tasks.coroutine
+ def coro():
+ yield
+
+ @tasks.coroutine
+ def wait_for_future():
+ gen = coro()
+ try:
+ yield gen
+ finally:
+ gen.close()
+
+ task = wait_for_future()
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_coroutine_non_gen_function(self):
+ @tasks.coroutine
+ def func():
+ return 'test'
+
+ self.assertTrue(tasks.iscoroutinefunction(func))
+
+ coro = func()
+ self.assertTrue(tasks.iscoroutine(coro))
+
+ res = self.loop.run_until_complete(coro)
+ self.assertEqual(res, 'test')
+
+ def test_coroutine_non_gen_function_return_future(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def func():
+ return fut
+
+ @tasks.coroutine
+ def coro():
+ fut.set_result('test')
+
+ t1 = tasks.Task(func(), loop=self.loop)
+ t2 = tasks.Task(coro(), loop=self.loop)
+ res = self.loop.run_until_complete(t1)
+ self.assertEqual(res, 'test')
+ self.assertIsNone(t2.result())
+
+ # Some thorough tests for cancellation propagation through
+ # coroutines, tasks and wait().
+
+ def test_yield_future_passes_cancel(self):
+ # Cancelling outer() cancels inner() cancels waiter.
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ try:
+ yield from waiter
+ except futures.CancelledError:
+ proof += 1
+ raise
+ else:
+ self.fail('got past sleep() in inner()')
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ try:
+ yield from inner()
+ except futures.CancelledError:
+ proof += 100 # Expect this path.
+ else:
+ proof += 10
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.loop.run_until_complete(f)
+ self.assertEqual(proof, 101)
+ self.assertTrue(waiter.cancelled())
+
+ def test_yield_wait_does_not_shield_cancel(self):
+ # Cancelling outer() makes wait() return early, leaves inner()
+ # running.
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ d, p = yield from tasks.wait([inner()], loop=self.loop)
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_result(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ inner.set_result(42)
+ res = self.loop.run_until_complete(outer)
+ self.assertEqual(res, 42)
+
+ def test_shield_exception(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ test_utils.run_briefly(self.loop)
+ exc = RuntimeError('expected')
+ inner.set_exception(exc)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(outer.exception(), exc)
+
+ def test_shield_cancel(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ test_utils.run_briefly(self.loop)
+ inner.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+
+ def test_shield_shortcut(self):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ res = self.loop.run_until_complete(tasks.shield(fut))
+ self.assertEqual(res, 42)
+
+ def test_shield_effect(self):
+ # Cancelling outer() does not affect inner().
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ yield from tasks.shield(inner(), loop=self.loop)
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_gather(self):
+ child1 = futures.Future(loop=self.loop)
+ child2 = futures.Future(loop=self.loop)
+ parent = tasks.gather(child1, child2, loop=self.loop)
+ outer = tasks.shield(parent, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(parent.result(), [1, 2])
+
+ def test_gather_shield(self):
+ child1 = futures.Future(loop=self.loop)
+ child2 = futures.Future(loop=self.loop)
+ inner1 = tasks.shield(child1, loop=self.loop)
+ inner2 = tasks.shield(child2, loop=self.loop)
+ parent = tasks.gather(inner1, inner2, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ parent.cancel()
+ # This should cancel inner1 and inner2 but bot child1 and child2.
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(parent.exception(), futures.CancelledError)
+ self.assertTrue(inner1.cancelled())
+ self.assertTrue(inner2.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+
+
+class GatherTestsBase:
+
+ def setUp(self):
+ self.one_loop = test_utils.TestLoop()
+ self.other_loop = test_utils.TestLoop()
+
+ def tearDown(self):
+ self.one_loop.close()
+ self.other_loop.close()
+
+ def _run_loop(self, loop):
+ while loop._ready:
+ test_utils.run_briefly(loop)
+
+ def _check_success(self, **kwargs):
+ a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ b.set_result(1)
+ a.set_result(2)
+ self._run_loop(self.one_loop)
+ self.assertEqual(cb.called, False)
+ self.assertFalse(fut.done())
+ c.set_result(3)
+ self._run_loop(self.one_loop)
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [2, 1, 3])
+
+ def test_success(self):
+ self._check_success()
+ self._check_success(return_exceptions=False)
+
+ def test_result_exception_success(self):
+ self._check_success(return_exceptions=True)
+
+ def test_one_exception(self):
+ a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c, d, e))
+ cb = Mock()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ a.set_result(1)
+ b.set_exception(exc)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertIs(fut.exception(), exc)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+
+ def test_return_exceptions(self):
+ a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c, d),
+ return_exceptions=True)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ exc2 = RuntimeError()
+ b.set_result(1)
+ c.set_exception(exc)
+ a.set_result(3)
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_exception(exc2)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [3, 1, exc, exc2])
+
+
+class FutureGatherTests(GatherTestsBase, unittest.TestCase):
+
+ def wrap_futures(self, *futures):
+ return futures
+
+ def _check_empty_sequence(self, seq_or_iter):
+ events.set_event_loop(self.one_loop)
+ self.addCleanup(events.set_event_loop, None)
+ fut = tasks.gather(*seq_or_iter)
+ self.assertIsInstance(fut, futures.Future)
+ self.assertIs(fut._loop, self.one_loop)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ self.assertEqual(fut.result(), [])
+ fut = tasks.gather(*seq_or_iter, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+
+ def test_constructor_empty_sequence(self):
+ self._check_empty_sequence([])
+ self._check_empty_sequence(())
+ self._check_empty_sequence(set())
+ self._check_empty_sequence(iter(""))
+
+ def test_constructor_heterogenous_futures(self):
+ fut1 = futures.Future(loop=self.one_loop)
+ fut2 = futures.Future(loop=self.other_loop)
+ with self.assertRaises(ValueError):
+ tasks.gather(fut1, fut2)
+ with self.assertRaises(ValueError):
+ tasks.gather(fut1, loop=self.other_loop)
+
+ def test_constructor_homogenous_futures(self):
+ children = [futures.Future(loop=self.other_loop) for i in range(3)]
+ fut = tasks.gather(*children)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+ fut = tasks.gather(*children, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+
+ def test_one_cancellation(self):
+ a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)]
+ fut = tasks.gather(a, b, c, d, e)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ b.cancel()
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertFalse(fut.cancelled())
+ self.assertIsInstance(fut.exception(), futures.CancelledError)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+
+ def test_result_exception_one_cancellation(self):
+ a, b, c, d, e, f = [futures.Future(loop=self.one_loop)
+ for i in range(6)]
+ fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ zde = ZeroDivisionError()
+ b.set_exception(zde)
+ c.cancel()
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_result(3)
+ e.cancel()
+ rte = RuntimeError()
+ f.set_exception(rte)
+ res = self.one_loop.run_until_complete(fut)
+ self.assertIsInstance(res[2], futures.CancelledError)
+ self.assertIsInstance(res[4], futures.CancelledError)
+ res[2] = res[4] = None
+ self.assertEqual(res, [1, zde, None, 3, None, rte])
+ cb.assert_called_once_with(fut)
+
+
+class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ events.set_event_loop(self.one_loop)
+
+ def tearDown(self):
+ events.set_event_loop(None)
+ super().tearDown()
+
+ def wrap_futures(self, *futures):
+ coros = []
+ for fut in futures:
+ @tasks.coroutine
+ def coro(fut=fut):
+ return (yield from fut)
+ coros.append(coro())
+ return coros
+
+ def test_constructor_loop_selection(self):
+ @tasks.coroutine
+ def coro():
+ return 'abc'
+ gen1 = coro()
+ gen2 = coro()
+ fut = tasks.gather(gen1, gen2)
+ self.assertIs(fut._loop, self.one_loop)
+ gen1.close()
+ gen2.close()
+ gen3 = coro()
+ gen4 = coro()
+ fut = tasks.gather(gen3, gen4, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+ gen3.close()
+ gen4.close()
+
+ def test_cancellation_broadcast(self):
+ # Cancelling outer() cancels all children.
+ proof = 0
+ waiter = futures.Future(loop=self.one_loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ child1 = tasks.async(inner(), loop=self.one_loop)
+ child2 = tasks.async(inner(), loop=self.one_loop)
+ gatherer = None
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof, gatherer
+ gatherer = tasks.gather(child1, child2, loop=self.one_loop)
+ yield from gatherer
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ self.assertTrue(f.cancel())
+ with self.assertRaises(futures.CancelledError):
+ self.one_loop.run_until_complete(f)
+ self.assertFalse(gatherer.cancel())
+ self.assertTrue(waiter.cancelled())
+ self.assertTrue(child1.cancelled())
+ self.assertTrue(child2.cancelled())
+ test_utils.run_briefly(self.one_loop)
+ self.assertEqual(proof, 0)
+
+ def test_exception_marking(self):
+ # Test for the first line marked "Mark exception retrieved."
+
+ @tasks.coroutine
+ def inner(f):
+ yield from f
+ raise RuntimeError('should not be ignored')
+
+ a = futures.Future(loop=self.one_loop)
+ b = futures.Future(loop=self.one_loop)
+
+ @tasks.coroutine
+ def outer():
+ yield from tasks.gather(inner(a), inner(b), loop=self.one_loop)
+
+ f = tasks.async(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ a.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ b.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ self.assertIsInstance(f.exception(), RuntimeError)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py
new file mode 100644
index 0000000..fce2e6f
--- /dev/null
+++ b/Lib/test/test_asyncio/test_transports.py
@@ -0,0 +1,55 @@
+"""Tests for transports.py."""
+
+import unittest
+import unittest.mock
+
+from asyncio import transports
+
+
+class TransportTests(unittest.TestCase):
+
+ def test_ctor_extra_is_none(self):
+ transport = transports.Transport()
+ self.assertEqual(transport._extra, {})
+
+ def test_get_extra_info(self):
+ transport = transports.Transport({'extra': 'info'})
+ self.assertEqual('info', transport.get_extra_info('extra'))
+ self.assertIsNone(transport.get_extra_info('unknown'))
+
+ default = object()
+ self.assertIs(default, transport.get_extra_info('unknown', default))
+
+ def test_writelines(self):
+ transport = transports.Transport()
+ transport.write = unittest.mock.Mock()
+
+ transport.writelines(['line1', 'line2', 'line3'])
+ self.assertEqual(3, transport.write.call_count)
+
+ def test_not_implemented(self):
+ transport = transports.Transport()
+
+ self.assertRaises(NotImplementedError, transport.write, 'data')
+ self.assertRaises(NotImplementedError, transport.write_eof)
+ self.assertRaises(NotImplementedError, transport.can_write_eof)
+ self.assertRaises(NotImplementedError, transport.pause)
+ self.assertRaises(NotImplementedError, transport.resume)
+ self.assertRaises(NotImplementedError, transport.close)
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_dgram_not_implemented(self):
+ transport = transports.DatagramTransport()
+
+ self.assertRaises(NotImplementedError, transport.sendto, 'data')
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_subprocess_transport_not_implemented(self):
+ transport = transports.SubprocessTransport()
+
+ self.assertRaises(NotImplementedError, transport.get_pid)
+ self.assertRaises(NotImplementedError, transport.get_returncode)
+ self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
+ self.assertRaises(NotImplementedError, transport.send_signal, 1)
+ self.assertRaises(NotImplementedError, transport.terminate)
+ self.assertRaises(NotImplementedError, transport.kill)
diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py
new file mode 100644
index 0000000..ea67862
--- /dev/null
+++ b/Lib/test/test_asyncio/test_unix_events.py
@@ -0,0 +1,767 @@
+"""Tests for unix_events.py."""
+
+import gc
+import errno
+import io
+import pprint
+import signal
+import stat
+import sys
+import unittest
+import unittest.mock
+
+
+from asyncio import events
+from asyncio import futures
+from asyncio import protocols
+from asyncio import test_utils
+from asyncio import unix_events
+
+
+@unittest.skipUnless(signal, 'Signals are not supported')
+class SelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = unix_events.SelectorEventLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_check_signal(self):
+ self.assertRaises(
+ TypeError, self.loop._check_signal, '1')
+ self.assertRaises(
+ ValueError, self.loop._check_signal, signal.NSIG + 1)
+
+ def test_handle_signal_no_handler(self):
+ self.loop._handle_signal(signal.NSIG + 1, ())
+
+ def test_handle_signal_cancelled_handler(self):
+ h = events.Handle(unittest.mock.Mock(), ())
+ h.cancel()
+ self.loop._signal_handlers[signal.NSIG + 1] = h
+ self.loop.remove_signal_handler = unittest.mock.Mock()
+ self.loop._handle_signal(signal.NSIG + 1, ())
+ self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_setup_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ cb = lambda: True
+ self.loop.add_signal_handler(signal.SIGHUP, cb)
+ h = self.loop._signal_handlers.get(signal.SIGHUP)
+ self.assertTrue(isinstance(h, events.Handle))
+ self.assertEqual(h._callback, cb)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_install_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ def set_wakeup_fd(fd):
+ if fd == -1:
+ raise ValueError()
+ m_signal.set_wakeup_fd = set_wakeup_fd
+
+ class Err(OSError):
+ errno = errno.EFAULT
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ Err,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ def test_add_signal_handler_install_error2(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.loop._signal_handlers[signal.SIGHUP] = lambda: True
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(1, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ def test_add_signal_handler_install_error3(self, m_logging, m_signal):
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+ m_signal.NSIG = signal.NSIG
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(2, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGHUP))
+ self.assertTrue(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.SIGINT = signal.SIGINT
+
+ self.loop.add_signal_handler(signal.SIGINT, lambda: True)
+ self.loop._signal_handlers[signal.SIGHUP] = object()
+ m_signal.set_wakeup_fd.reset_mock()
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertFalse(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGINT, m_signal.default_int_handler),
+ m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.loop.remove_signal_handler(signal.SIGHUP)
+ self.assertTrue(m_logging.info)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.signal.side_effect = OSError
+
+ self.assertRaises(
+ OSError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = True
+ m_WIFSIGNALED.return_value = False
+ m_WEXITSTATUS.return_value = 3
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ transp._process_exited.assert_called_with(3)
+ self.assertFalse(m_WTERMSIG.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = False
+ m_WIFSIGNALED.return_value = True
+ m_WTERMSIG.return_value = 1
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ transp._process_exited.assert_called_with(-1)
+ self.assertFalse(m_WEXITSTATUS.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(0, object()), ChildProcessError]
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ self.assertFalse(transp._process_exited.called)
+ self.assertFalse(m_WIFSIGNALED.called)
+ self.assertFalse(m_WIFEXITED.called)
+ self.assertFalse(m_WTERMSIG.called)
+ self.assertFalse(m_WEXITSTATUS.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_not_registered_subprocess(self, m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = True
+ m_WIFSIGNALED.return_value = False
+ m_WEXITSTATUS.return_value = 3
+
+ self.loop._sig_chld()
+ self.assertFalse(m_WTERMSIG.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_unknown_status(self, m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = False
+ m_WIFSIGNALED.return_value = False
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ self.assertFalse(transp._process_exited.called)
+ self.assertFalse(m_WEXITSTATUS.called)
+ self.assertFalse(m_WTERMSIG.called)
+
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_unknown_status_in_handler(self, m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG,
+ m_log):
+ m_waitpid.side_effect = Exception
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ self.assertFalse(transp._process_exited.called)
+ self.assertFalse(m_WIFSIGNALED.called)
+ self.assertFalse(m_WIFEXITED.called)
+ self.assertFalse(m_WTERMSIG.called)
+ self.assertFalse(m_WEXITSTATUS.called)
+ m_log.exception.assert_called_with(
+ 'Unknown exception in SIGCHLD handler')
+
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_process_error(self, m_waitpid):
+ m_waitpid.side_effect = ChildProcessError
+ self.loop._sig_chld()
+ self.assertTrue(m_waitpid.called)
+
+
+class UnixReadPipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(protocols.Protocol)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ fcntl_patcher = unittest.mock.patch('fcntl.fcntl')
+ fcntl_patcher.start()
+ self.addCleanup(fcntl_patcher.stop)
+
+ def test_ctor(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+ unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.return_value = b'data'
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready_eof(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.return_value = b''
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.eof_received.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready_blocked(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.side_effect = BlockingIOError
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.data_received.called)
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ @unittest.mock.patch('os.read')
+ def test__read_ready_error(self, m_read, m_logexc):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ err = OSError()
+ m_read.side_effect = err
+ tr._close = unittest.mock.Mock()
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ tr._close.assert_called_with(err)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+
+ @unittest.mock.patch('os.read')
+ def test_pause(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m = unittest.mock.Mock()
+ self.loop.add_reader(5, m)
+ tr.pause()
+ self.assertFalse(self.loop.readers)
+
+ @unittest.mock.patch('os.read')
+ def test_resume(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.resume()
+ self.loop.assert_reader(5, tr._read_ready)
+
+ @unittest.mock.patch('os.read')
+ def test_close(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ tr._close.assert_called_with(None)
+
+ @unittest.mock.patch('os.read')
+ def test_close_already_closing(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr._closing = True
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ self.assertFalse(tr._close.called)
+
+ @unittest.mock.patch('os.read')
+ def test__close(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = object()
+ tr._close(err)
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ def test__call_connection_lost(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test__call_connection_lost_with_err(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+
+class UnixWritePipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ fcntl_patcher = unittest.mock.patch('fcntl.fcntl')
+ fcntl_patcher.start()
+ self.addCleanup(fcntl_patcher.stop)
+
+ fstat_patcher = unittest.mock.patch('os.fstat')
+ m_fstat = fstat_patcher.start()
+ st = unittest.mock.Mock()
+ st.st_mode = stat.S_IFIFO
+ m_fstat.return_value = st
+ self.addCleanup(fstat_patcher.stop)
+
+ def test_ctor(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol, fut)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(None, fut.result())
+
+ def test_can_write_eof(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+
+ @unittest.mock.patch('os.write')
+ def test_write(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.return_value = 4
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_no_data(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write(b'')
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_partial(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.return_value = 2
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'ta'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_buffer(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'previous']
+ tr.write(b'data')
+ self.assertFalse(m_write.called)
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'previous', b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_again(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.side_effect = BlockingIOError()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ @unittest.mock.patch('os.write')
+ def test_write_err(self, m_write, m_log):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ m_write.side_effect = err
+ tr._fatal_error = unittest.mock.Mock()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ tr._fatal_error.assert_called_with(err)
+ self.assertEqual(1, tr._conn_lost)
+
+ tr.write(b'data')
+ self.assertEqual(2, tr._conn_lost)
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ # This is a bit overspecified. :-(
+ m_log.warning.assert_called_with(
+ 'pipe closed by peer or os.write(pipe, data) raised exception.')
+
+ @unittest.mock.patch('os.write')
+ def test_write_close(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ tr._read_ready() # pipe was closed by peer
+
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 1)
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 2)
+
+ def test__read_ready(self):
+ tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe,
+ self.protocol)
+ tr._read_ready()
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertTrue(tr._closing)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_partial(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 3
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'a'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_again(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = BlockingIOError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_empty(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 0
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ @unittest.mock.patch('os.write')
+ def test__write_ready_err(self, m_write, m_logexc):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = err = OSError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+ self.assertEqual(1, tr._conn_lost)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_closing(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._closing = True
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.pipe.close.assert_called_with()
+
+ @unittest.mock.patch('os.write')
+ def test_abort(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ self.loop.add_reader(5, tr._read_ready)
+ tr._buffer = [b'da', b'ta']
+ tr.abort()
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test__call_connection_lost(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test__call_connection_lost_with_err(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test_close(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr.close()
+ tr.write_eof.assert_called_with()
+
+ def test_close_closing(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr._closing = True
+ tr.close()
+ self.assertFalse(tr.write_eof.called)
+
+ def test_write_eof(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_eof_pending(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ tr._buffer = [b'data']
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.protocol.connection_lost.called)
diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py
new file mode 100644
index 0000000..4b04073
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_events.py
@@ -0,0 +1,95 @@
+import os
+import sys
+import unittest
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import asyncio
+
+from asyncio import windows_events
+from asyncio import protocols
+from asyncio import streams
+from asyncio import transports
+from asyncio import test_utils
+
+
+class UpperProto(protocols.Protocol):
+ def __init__(self):
+ self.buf = []
+
+ def connection_made(self, trans):
+ self.trans = trans
+
+ def data_received(self, data):
+ self.buf.append(data)
+ if b'\n' in data:
+ self.trans.write(b''.join(self.buf).upper())
+ self.trans.close()
+
+
+class ProactorTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = windows_events.ProactorEventLoop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ self.loop = None
+
+ def test_close(self):
+ a, b = self.loop._socketpair()
+ trans = self.loop._make_socket_transport(a, protocols.Protocol())
+ f = asyncio.async(self.loop.sock_recv(b, 100))
+ trans.close()
+ self.loop.run_until_complete(f)
+ self.assertEqual(f.result(), b'')
+
+ def test_double_bind(self):
+ ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid()
+ server1 = windows_events.PipeServer(ADDRESS)
+ with self.assertRaises(PermissionError):
+ server2 = windows_events.PipeServer(ADDRESS)
+ server1.close()
+
+ def test_pipe(self):
+ res = self.loop.run_until_complete(self._test_pipe())
+ self.assertEqual(res, 'done')
+
+ def _test_pipe(self):
+ ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ protocols.Protocol, ADDRESS)
+
+ [server] = yield from self.loop.start_serving_pipe(
+ UpperProto, ADDRESS)
+ self.assertIsInstance(server, windows_events.PipeServer)
+
+ clients = []
+ for i in range(5):
+ stream_reader = streams.StreamReader(loop=self.loop)
+ protocol = streams.StreamReaderProtocol(stream_reader)
+ trans, proto = yield from self.loop.create_pipe_connection(
+ lambda:protocol, ADDRESS)
+ self.assertIsInstance(trans, transports.Transport)
+ self.assertEqual(protocol, proto)
+ clients.append((stream_reader, trans))
+
+ for i, (r, w) in enumerate(clients):
+ w.write('lower-{}\n'.format(i).encode())
+
+ for i, (r, w) in enumerate(clients):
+ response = yield from r.readline()
+ self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
+ w.close()
+
+ server.close()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ protocols.Protocol, ADDRESS)
+
+ return 'done'
diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py
new file mode 100644
index 0000000..4b96086
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_utils.py
@@ -0,0 +1,136 @@
+"""Tests for window_utils"""
+
+import sys
+import test.support
+import unittest
+import unittest.mock
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import _winapi
+
+from asyncio import windows_utils
+from asyncio import _overlapped
+
+
+class WinsocketpairTests(unittest.TestCase):
+
+ def test_winsocketpair(self):
+ ssock, csock = windows_utils.socketpair()
+
+ csock.send(b'xxx')
+ self.assertEqual(b'xxx', ssock.recv(1024))
+
+ csock.close()
+ ssock.close()
+
+ @unittest.mock.patch('asyncio.windows_utils.socket')
+ def test_winsocketpair_exc(self, m_socket):
+ m_socket.socket.return_value.getsockname.return_value = ('', 12345)
+ m_socket.socket.return_value.accept.return_value = object(), object()
+ m_socket.socket.return_value.connect.side_effect = OSError()
+
+ self.assertRaises(OSError, windows_utils.socketpair)
+
+
+class PipeTests(unittest.TestCase):
+
+ def test_pipe_overlapped(self):
+ h1, h2 = windows_utils.pipe(overlapped=(True, True))
+ try:
+ ov1 = _overlapped.Overlapped()
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, 0)
+
+ ov1.ReadFile(h1, 100)
+ self.assertTrue(ov1.pending)
+ self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING)
+ ERROR_IO_INCOMPLETE = 996
+ try:
+ ov1.getresult()
+ except OSError as e:
+ self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE)
+ else:
+ raise RuntimeError('expected ERROR_IO_INCOMPLETE')
+
+ ov2 = _overlapped.Overlapped()
+ self.assertFalse(ov2.pending)
+ self.assertEqual(ov2.error, 0)
+
+ ov2.WriteFile(h2, b"hello")
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+
+ res = _winapi.WaitForMultipleObjects([ov2.event], False, 100)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE)
+ self.assertFalse(ov2.pending)
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+ self.assertEqual(ov1.getresult(), b"hello")
+ finally:
+ _winapi.CloseHandle(h1)
+ _winapi.CloseHandle(h2)
+
+ def test_pipe_handle(self):
+ h, _ = windows_utils.pipe(overlapped=(True, True))
+ _winapi.CloseHandle(_)
+ p = windows_utils.PipeHandle(h)
+ self.assertEqual(p.fileno(), h)
+ self.assertEqual(p.handle, h)
+
+ # check garbage collection of p closes handle
+ del p
+ test.support.gc_collect()
+ try:
+ _winapi.CloseHandle(h)
+ except OSError as e:
+ self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE
+ else:
+ raise RuntimeError('expected ERROR_INVALID_HANDLE')
+
+
+class PopenTests(unittest.TestCase):
+
+ def test_popen(self):
+ command = r"""if 1:
+ import sys
+ s = sys.stdin.readline()
+ sys.stdout.write(s.upper())
+ sys.stderr.write('stderr')
+ """
+ msg = b"blah\n"
+
+ p = windows_utils.Popen([sys.executable, '-c', command],
+ stdin=windows_utils.PIPE,
+ stdout=windows_utils.PIPE,
+ stderr=windows_utils.PIPE)
+
+ for f in [p.stdin, p.stdout, p.stderr]:
+ self.assertIsInstance(f, windows_utils.PipeHandle)
+
+ ovin = _overlapped.Overlapped()
+ ovout = _overlapped.Overlapped()
+ overr = _overlapped.Overlapped()
+
+ ovin.WriteFile(p.stdin.handle, msg)
+ ovout.ReadFile(p.stdout.handle, 100)
+ overr.ReadFile(p.stderr.handle, 100)
+
+ events = [ovin.event, ovout.event, overr.event]
+ res = _winapi.WaitForMultipleObjects(events, True, 2000)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+ self.assertFalse(ovout.pending)
+ self.assertFalse(overr.pending)
+ self.assertFalse(ovin.pending)
+
+ self.assertEqual(ovin.getresult(), len(msg))
+ out = ovout.getresult().rstrip()
+ err = overr.getresult().rstrip()
+
+ self.assertGreater(len(out), 0)
+ self.assertGreater(len(err), 0)
+ # allow for partial reads...
+ self.assertTrue(msg.upper().rstrip().startswith(out))
+ self.assertTrue(b"stderr".startswith(err))
diff --git a/Lib/test/test_asyncio/tests.txt b/Lib/test/test_asyncio/tests.txt
new file mode 100644
index 0000000..e947721
--- /dev/null
+++ b/Lib/test/test_asyncio/tests.txt
@@ -0,0 +1,14 @@
+test_asyncio.test_base_events
+test_asyncio.test_events
+test_asyncio.test_futures
+test_asyncio.test_locks
+test_asyncio.test_proactor_events
+test_asyncio.test_queues
+test_asyncio.test_selector_events
+test_asyncio.test_selectors
+test_asyncio.test_streams
+test_asyncio.test_tasks
+test_asyncio.test_transports
+test_asyncio.test_unix_events
+test_asyncio.test_windows_events
+test_asyncio.test_windows_utils