diff options
Diffstat (limited to 'Lib/test/test_coroutines.py')
-rw-r--r-- | Lib/test/test_coroutines.py | 862 |
1 files changed, 862 insertions, 0 deletions
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py new file mode 100644 index 0000000..2a45225 --- /dev/null +++ b/Lib/test/test_coroutines.py @@ -0,0 +1,862 @@ +import contextlib +import gc +import sys +import types +import unittest +import warnings +from test import support + + +class AsyncYieldFrom: + def __init__(self, obj): + self.obj = obj + + def __await__(self): + yield from self.obj + + +class AsyncYield: + def __init__(self, value): + self.value = value + + def __await__(self): + yield self.value + + +def run_async(coro): + assert coro.__class__ is types.GeneratorType + + buffer = [] + result = None + while True: + try: + buffer.append(coro.send(None)) + except StopIteration as ex: + result = ex.args[0] if ex.args else None + break + return buffer, result + + +@contextlib.contextmanager +def silence_coro_gc(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + yield + support.gc_collect() + + +class AsyncBadSyntaxTest(unittest.TestCase): + + def test_badsyntax_1(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async1 + + def test_badsyntax_2(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async2 + + def test_badsyntax_3(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async3 + + def test_badsyntax_4(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async4 + + def test_badsyntax_5(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async5 + + def test_badsyntax_6(self): + with self.assertRaisesRegex( + SyntaxError, "'yield' inside async function"): + + import test.badsyntax_async6 + + def test_badsyntax_7(self): + with self.assertRaisesRegex( + SyntaxError, "'yield from' inside async function"): + + import test.badsyntax_async7 + + def test_badsyntax_8(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async8 + + def test_badsyntax_9(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async9 + + +class CoroutineTest(unittest.TestCase): + + def test_gen_1(self): + def gen(): yield + self.assertFalse(hasattr(gen, '__await__')) + + def test_func_1(self): + async def foo(): + return 10 + + f = foo() + self.assertIsInstance(f, types.GeneratorType) + self.assertTrue(bool(foo.__code__.co_flags & 0x80)) + self.assertTrue(bool(foo.__code__.co_flags & 0x20)) + self.assertTrue(bool(f.gi_code.co_flags & 0x80)) + self.assertTrue(bool(f.gi_code.co_flags & 0x20)) + self.assertEqual(run_async(f), ([], 10)) + + def bar(): pass + self.assertFalse(bool(bar.__code__.co_flags & 0x80)) + + def test_func_2(self): + async def foo(): + raise StopIteration + + with self.assertRaisesRegex( + RuntimeError, "generator raised StopIteration"): + + run_async(foo()) + + def test_func_3(self): + async def foo(): + raise StopIteration + + with silence_coro_gc(): + self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$') + + def test_func_4(self): + async def foo(): + raise StopIteration + + check = lambda: self.assertRaisesRegex( + TypeError, "coroutine-objects do not support iteration") + + with check(): + list(foo()) + + with check(): + tuple(foo()) + + with check(): + sum(foo()) + + with check(): + iter(foo()) + + with check(): + next(foo()) + + with silence_coro_gc(), check(): + for i in foo(): + pass + + with silence_coro_gc(), check(): + [i for i in foo()] + + def test_func_5(self): + @types.coroutine + def bar(): + yield 1 + + async def foo(): + await bar() + + check = lambda: self.assertRaisesRegex( + TypeError, "coroutine-objects do not support iteration") + + with check(): + for el in foo(): pass + + # the following should pass without an error + for el in bar(): + self.assertEqual(el, 1) + self.assertEqual([el for el in bar()], [1]) + self.assertEqual(tuple(bar()), (1,)) + self.assertEqual(next(iter(bar())), 1) + + def test_func_6(self): + @types.coroutine + def bar(): + yield 1 + yield 2 + + async def foo(): + await bar() + + f = foo() + self.assertEquals(f.send(None), 1) + self.assertEquals(f.send(None), 2) + with self.assertRaises(StopIteration): + f.send(None) + + def test_func_7(self): + async def bar(): + return 10 + + def foo(): + yield from bar() + + with silence_coro_gc(), self.assertRaisesRegex( + TypeError, + "cannot 'yield from' a coroutine object from a generator"): + + list(foo()) + + def test_func_8(self): + @types.coroutine + def bar(): + return (yield from foo()) + + async def foo(): + return 'spam' + + self.assertEqual(run_async(bar()), ([], 'spam') ) + + def test_func_9(self): + async def foo(): pass + + with self.assertWarnsRegex( + RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"): + + foo() + support.gc_collect() + + def test_await_1(self): + + async def foo(): + await 1 + with self.assertRaisesRegex(TypeError, "object int can.t.*await"): + run_async(foo()) + + def test_await_2(self): + async def foo(): + await [] + with self.assertRaisesRegex(TypeError, "object list can.t.*await"): + run_async(foo()) + + def test_await_3(self): + async def foo(): + await AsyncYieldFrom([1, 2, 3]) + + self.assertEqual(run_async(foo()), ([1, 2, 3], None)) + + def test_await_4(self): + async def bar(): + return 42 + + async def foo(): + return await bar() + + self.assertEqual(run_async(foo()), ([], 42)) + + def test_await_5(self): + class Awaitable: + def __await__(self): + return + + async def foo(): + return (await Awaitable()) + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type"): + + run_async(foo()) + + def test_await_6(self): + class Awaitable: + def __await__(self): + return iter([52]) + + async def foo(): + return (await Awaitable()) + + self.assertEqual(run_async(foo()), ([52], None)) + + def test_await_7(self): + class Awaitable: + def __await__(self): + yield 42 + return 100 + + async def foo(): + return (await Awaitable()) + + self.assertEqual(run_async(foo()), ([42], 100)) + + def test_await_8(self): + class Awaitable: + pass + + async def foo(): + return (await Awaitable()) + + with self.assertRaisesRegex( + TypeError, "object Awaitable can't be used in 'await' expression"): + + run_async(foo()) + + def test_await_9(self): + def wrap(): + return bar + + async def bar(): + return 42 + + async def foo(): + b = bar() + + db = {'b': lambda: wrap} + + class DB: + b = wrap + + return (await bar() + await wrap()() + await db['b']()()() + + await bar() * 1000 + await DB.b()()) + + async def foo2(): + return -await bar() + + self.assertEqual(run_async(foo()), ([], 42168)) + self.assertEqual(run_async(foo2()), ([], -42)) + + def test_await_10(self): + async def baz(): + return 42 + + async def bar(): + return baz() + + async def foo(): + return await (await bar()) + + self.assertEqual(run_async(foo()), ([], 42)) + + def test_await_11(self): + def ident(val): + return val + + async def bar(): + return 'spam' + + async def foo(): + return ident(val=await bar()) + + async def foo2(): + return await bar(), 'ham' + + self.assertEqual(run_async(foo2()), ([], ('spam', 'ham'))) + + def test_await_12(self): + async def coro(): + return 'spam' + + class Awaitable: + def __await__(self): + return coro() + + async def foo(): + return await Awaitable() + + with self.assertRaisesRegex( + TypeError, "__await__\(\) returned a coroutine"): + + run_async(foo()) + + def test_await_13(self): + class Awaitable: + def __await__(self): + return self + + async def foo(): + return await Awaitable() + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type"): + + run_async(foo()) + + def test_with_1(self): + class Manager: + def __init__(self, name): + self.name = name + + async def __aenter__(self): + await AsyncYieldFrom(['enter-1-' + self.name, + 'enter-2-' + self.name]) + return self + + async def __aexit__(self, *args): + await AsyncYieldFrom(['exit-1-' + self.name, + 'exit-2-' + self.name]) + + if self.name == 'B': + return True + + + async def foo(): + async with Manager("A") as a, Manager("B") as b: + await AsyncYieldFrom([('managers', a.name, b.name)]) + 1/0 + + f = foo() + result, _ = run_async(f) + + self.assertEqual( + result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B', + ('managers', 'A', 'B'), + 'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A'] + ) + + async def foo(): + async with Manager("A") as a, Manager("C") as c: + await AsyncYieldFrom([('managers', a.name, c.name)]) + 1/0 + + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + + def test_with_2(self): + class CM: + def __aenter__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aexit__'): + run_async(foo()) + + def test_with_3(self): + class CM: + def __aexit__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aenter__'): + run_async(foo()) + + def test_with_4(self): + class CM: + def __enter__(self): + pass + + def __exit__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aexit__'): + run_async(foo()) + + def test_with_5(self): + # While this test doesn't make a lot of sense, + # it's a regression test for an early bug with opcodes + # generation + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + pass + + async def func(): + async with CM(): + assert (1, ) == 1 + + with self.assertRaises(AssertionError): + run_async(func()) + + def test_with_6(self): + class CM: + def __aenter__(self): + return 123 + + def __aexit__(self, *e): + return 456 + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex( + TypeError, "object int can't be used in 'await' expression"): + # it's important that __aexit__ wasn't called + run_async(foo()) + + def test_with_7(self): + class CM: + async def __aenter__(self): + return self + + def __aexit__(self, *e): + return 456 + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex( + TypeError, "object int can't be used in 'await' expression"): + + run_async(foo()) + + def test_for_1(self): + aiter_calls = 0 + + class AsyncIter: + def __init__(self): + self.i = 0 + + async def __aiter__(self): + nonlocal aiter_calls + aiter_calls += 1 + return self + + async def __anext__(self): + self.i += 1 + + if not (self.i % 10): + await AsyncYield(self.i * 10) + + if self.i > 100: + raise StopAsyncIteration + + return self.i, self.i + + + buffer = [] + async def test1(): + async for i1, i2 in AsyncIter(): + buffer.append(i1 + i2) + + yielded, _ = run_async(test1()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 1) + self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) + self.assertEqual(buffer, [i*2 for i in range(1, 101)]) + + + buffer = [] + async def test2(): + nonlocal buffer + async for i in AsyncIter(): + buffer.append(i[0]) + if i[0] == 20: + break + else: + buffer.append('what?') + buffer.append('end') + + yielded, _ = run_async(test2()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 2) + self.assertEqual(yielded, [100, 200]) + self.assertEqual(buffer, [i for i in range(1, 21)] + ['end']) + + + buffer = [] + async def test3(): + nonlocal buffer + async for i in AsyncIter(): + if i[0] > 20: + continue + buffer.append(i[0]) + else: + buffer.append('what?') + buffer.append('end') + + yielded, _ = run_async(test3()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 3) + self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) + self.assertEqual(buffer, [i for i in range(1, 21)] + + ['what?', 'end']) + + def test_for_2(self): + tup = (1, 2, 3) + refs_before = sys.getrefcount(tup) + + async def foo(): + async for i in tup: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, "async for' requires an object.*__aiter__.*tuple"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(tup), refs_before) + + def test_for_3(self): + class I: + def __aiter__(self): + return self + + aiter = I() + refs_before = sys.getrefcount(aiter) + + async def foo(): + async for i in aiter: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__aiter.*\: I"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(aiter), refs_before) + + def test_for_4(self): + class I: + async def __aiter__(self): + return self + + def __anext__(self): + return () + + aiter = I() + refs_before = sys.getrefcount(aiter) + + async def foo(): + async for i in aiter: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__anext__.*tuple"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(aiter), refs_before) + + def test_for_5(self): + class I: + async def __aiter__(self): + return self + + def __anext__(self): + return 123 + + async def foo(): + async for i in I(): + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__anext.*int"): + + run_async(foo()) + + def test_for_6(self): + I = 0 + + class Manager: + async def __aenter__(self): + nonlocal I + I += 10000 + + async def __aexit__(self, *args): + nonlocal I + I += 100000 + + class Iterable: + def __init__(self): + self.i = 0 + + async def __aiter__(self): + return self + + async def __anext__(self): + if self.i > 10: + raise StopAsyncIteration + self.i += 1 + return self.i + + ############## + + manager = Manager() + iterable = Iterable() + mrefs_before = sys.getrefcount(manager) + irefs_before = sys.getrefcount(iterable) + + async def main(): + nonlocal I + + async with manager: + async for i in iterable: + I += 1 + I += 1000 + + run_async(main()) + self.assertEqual(I, 111011) + + self.assertEqual(sys.getrefcount(manager), mrefs_before) + self.assertEqual(sys.getrefcount(iterable), irefs_before) + + ############## + + async def main(): + nonlocal I + + async with Manager(): + async for i in Iterable(): + I += 1 + I += 1000 + + async with Manager(): + async for i in Iterable(): + I += 1 + I += 1000 + + run_async(main()) + self.assertEqual(I, 333033) + + ############## + + async def main(): + nonlocal I + + async with Manager(): + I += 100 + async for i in Iterable(): + I += 1 + else: + I += 10000000 + I += 1000 + + async with Manager(): + I += 100 + async for i in Iterable(): + I += 1 + else: + I += 10000000 + I += 1000 + + run_async(main()) + self.assertEqual(I, 20555255) + + +class CoroAsyncIOCompatTest(unittest.TestCase): + + def test_asyncio_1(self): + import asyncio + + class MyException(Exception): + pass + + buffer = [] + + class CM: + async def __aenter__(self): + buffer.append(1) + await asyncio.sleep(0.01) + buffer.append(2) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await asyncio.sleep(0.01) + buffer.append(exc_type.__name__) + + async def f(): + async with CM() as c: + await asyncio.sleep(0.01) + raise MyException + buffer.append('unreachable') + + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(f()) + except MyException: + pass + finally: + loop.close() + + self.assertEqual(buffer, [1, 2, 'MyException']) + + +class SysSetCoroWrapperTest(unittest.TestCase): + + def test_set_wrapper_1(self): + async def foo(): + return 'spam' + + wrapped = None + def wrap(gen): + nonlocal wrapped + wrapped = gen + return gen + + self.assertIsNone(sys.get_coroutine_wrapper()) + + sys.set_coroutine_wrapper(wrap) + self.assertIs(sys.get_coroutine_wrapper(), wrap) + try: + f = foo() + self.assertTrue(wrapped) + + self.assertEqual(run_async(f), ([], 'spam')) + finally: + sys.set_coroutine_wrapper(None) + + self.assertIsNone(sys.get_coroutine_wrapper()) + + wrapped = None + with silence_coro_gc(): + foo() + self.assertFalse(wrapped) + + def test_set_wrapper_2(self): + self.assertIsNone(sys.get_coroutine_wrapper()) + with self.assertRaisesRegex(TypeError, "callable expected, got int"): + sys.set_coroutine_wrapper(1) + self.assertIsNone(sys.get_coroutine_wrapper()) + + +class CAPITest(unittest.TestCase): + + def test_tp_await_1(self): + from _testcapi import awaitType as at + + async def foo(): + future = at(iter([1])) + return (await future) + + self.assertEqual(foo().send(None), 1) + + def test_tp_await_2(self): + # Test tp_await to __await__ mapping + from _testcapi import awaitType as at + future = at(iter([1])) + self.assertEqual(next(future.__await__()), 1) + + def test_tp_await_3(self): + from _testcapi import awaitType as at + + async def foo(): + future = at(1) + return (await future) + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type 'int'"): + self.assertEqual(foo().send(None), 1) + + +def test_main(): + support.run_unittest(AsyncBadSyntaxTest, + CoroutineTest, + CoroAsyncIOCompatTest, + SysSetCoroWrapperTest, + CAPITest) + + +if __name__=="__main__": + test_main() |