import contextlib import gc import sys import types import unittest import warnings from test import support class AsyncYieldFrom: def __init__(self, obj): self.obj = obj def __await__(self): yield from self.obj class AsyncYield: def __init__(self, value): self.value = value def __await__(self): yield self.value def run_async(coro): assert coro.__class__ is types.GeneratorType buffer = [] result = None while True: try: buffer.append(coro.send(None)) except StopIteration as ex: result = ex.args[0] if ex.args else None break return buffer, result @contextlib.contextmanager def silence_coro_gc(): with warnings.catch_warnings(): warnings.simplefilter("ignore") yield support.gc_collect() class AsyncBadSyntaxTest(unittest.TestCase): def test_badsyntax_1(self): with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): import test.badsyntax_async1 def test_badsyntax_2(self): with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): import test.badsyntax_async2 def test_badsyntax_3(self): with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): import test.badsyntax_async3 def test_badsyntax_4(self): with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): import test.badsyntax_async4 def test_badsyntax_5(self): with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): import test.badsyntax_async5 def test_badsyntax_6(self): with self.assertRaisesRegex( SyntaxError, "'yield' inside async function"): import test.badsyntax_async6 def test_badsyntax_7(self): with self.assertRaisesRegex( SyntaxError, "'yield from' inside async function"): import test.badsyntax_async7 def test_badsyntax_8(self): with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): import test.badsyntax_async8 def test_badsyntax_9(self): with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): import test.badsyntax_async9 class CoroutineTest(unittest.TestCase): def test_gen_1(self): def gen(): yield self.assertFalse(hasattr(gen, '__await__')) def test_func_1(self): async def foo(): return 10 f = foo() self.assertIsInstance(f, types.GeneratorType) self.assertTrue(bool(foo.__code__.co_flags & 0x80)) self.assertTrue(bool(foo.__code__.co_flags & 0x20)) self.assertTrue(bool(f.gi_code.co_flags & 0x80)) self.assertTrue(bool(f.gi_code.co_flags & 0x20)) self.assertEqual(run_async(f), ([], 10)) def bar(): pass self.assertFalse(bool(bar.__code__.co_flags & 0x80)) def test_func_2(self): async def foo(): raise StopIteration with self.assertRaisesRegex( RuntimeError, "generator raised StopIteration"): run_async(foo()) def test_func_3(self): async def foo(): raise StopIteration with silence_coro_gc(): self.assertRegex(repr(foo()), '^$') def test_func_4(self): async def foo(): raise StopIteration check = lambda: self.assertRaisesRegex( TypeError, "coroutine-objects do not support iteration") with check(): list(foo()) with check(): tuple(foo()) with check(): sum(foo()) with check(): iter(foo()) with check(): next(foo()) with silence_coro_gc(), check(): for i in foo(): pass with silence_coro_gc(), check(): [i for i in foo()] def test_func_5(self): @types.coroutine def bar(): yield 1 async def foo(): await bar() check = lambda: self.assertRaisesRegex( TypeError, "coroutine-objects do not support iteration") with check(): for el in foo(): pass # the following should pass without an error for el in bar(): self.assertEqual(el, 1) self.assertEqual([el for el in bar()], [1]) self.assertEqual(tuple(bar()), (1,)) self.assertEqual(next(iter(bar())), 1) def test_func_6(self): @types.coroutine def bar(): yield 1 yield 2 async def foo(): await bar() f = foo() self.assertEquals(f.send(None), 1) self.assertEquals(f.send(None), 2) with self.assertRaises(StopIteration): f.send(None) def test_func_7(self): async def bar(): return 10 def foo(): yield from bar() with silence_coro_gc(), self.assertRaisesRegex( TypeError, "cannot 'yield from' a coroutine object from a generator"): list(foo()) def test_func_8(self): @types.coroutine def bar(): return (yield from foo()) async def foo(): return 'spam' self.assertEqual(run_async(bar()), ([], 'spam') ) def test_func_9(self): async def foo(): pass with self.assertWarnsRegex( RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"): foo() support.gc_collect() def test_await_1(self): async def foo(): await 1 with self.assertRaisesRegex(TypeError, "object int can.t.*await"): run_async(foo()) def test_await_2(self): async def foo(): await [] with self.assertRaisesRegex(TypeError, "object list can.t.*await"): run_async(foo()) def test_await_3(self): async def foo(): await AsyncYieldFrom([1, 2, 3]) self.assertEqual(run_async(foo()), ([1, 2, 3], None)) def test_await_4(self): async def bar(): return 42 async def foo(): return await bar() self.assertEqual(run_async(foo()), ([], 42)) def test_await_5(self): class Awaitable: def __await__(self): return async def foo(): return (await Awaitable()) with self.assertRaisesRegex( TypeError, "__await__.*returned non-iterator of type"): run_async(foo()) def test_await_6(self): class Awaitable: def __await__(self): return iter([52]) async def foo(): return (await Awaitable()) self.assertEqual(run_async(foo()), ([52], None)) def test_await_7(self): class Awaitable: def __await__(self): yield 42 return 100 async def foo(): return (await Awaitable()) self.assertEqual(run_async(foo()), ([42], 100)) def test_await_8(self): class Awaitable: pass async def foo(): return (await Awaitable()) with self.assertRaisesRegex( TypeError, "object Awaitable can't be used in 'await' expression"): run_async(foo()) def test_await_9(self): def wrap(): return bar async def bar(): return 42 async def foo(): b = bar() db = {'b': lambda: wrap} class DB: b = wrap return (await bar() + await wrap()() + await db['b']()()() + await bar() * 1000 + await DB.b()()) async def foo2(): return -await bar() self.assertEqual(run_async(foo()), ([], 42168)) self.assertEqual(run_async(foo2()), ([], -42)) def test_await_10(self): async def baz(): return 42 async def bar(): return baz() async def foo(): return await (await bar()) self.assertEqual(run_async(foo()), ([], 42)) def test_await_11(self): def ident(val): return val async def bar(): return 'spam' async def foo(): return ident(val=await bar()) async def foo2(): return await bar(), 'ham' self.assertEqual(run_async(foo2()), ([], ('spam', 'ham'))) def test_await_12(self): async def coro(): return 'spam' class Awaitable: def __await__(self): return coro() async def foo(): return await Awaitable() with self.assertRaisesRegex( TypeError, "__await__\(\) returned a coroutine"): run_async(foo()) def test_await_13(self): class Awaitable: def __await__(self): return self async def foo(): return await Awaitable() with self.assertRaisesRegex( TypeError, "__await__.*returned non-iterator of type"): run_async(foo()) def test_with_1(self): class Manager: def __init__(self, name): self.name = name async def __aenter__(self): await AsyncYieldFrom(['enter-1-' + self.name, 'enter-2-' + self.name]) return self async def __aexit__(self, *args): await AsyncYieldFrom(['exit-1-' + self.name, 'exit-2-' + self.name]) if self.name == 'B': return True async def foo(): async with Manager("A") as a, Manager("B") as b: await AsyncYieldFrom([('managers', a.name, b.name)]) 1/0 f = foo() result, _ = run_async(f) self.assertEqual( result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B', ('managers', 'A', 'B'), 'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A'] ) async def foo(): async with Manager("A") as a, Manager("C") as c: await AsyncYieldFrom([('managers', a.name, c.name)]) 1/0 with self.assertRaises(ZeroDivisionError): run_async(foo()) def test_with_2(self): class CM: def __aenter__(self): pass async def foo(): async with CM(): pass with self.assertRaisesRegex(AttributeError, '__aexit__'): run_async(foo()) def test_with_3(self): class CM: def __aexit__(self): pass async def foo(): async with CM(): pass with self.assertRaisesRegex(AttributeError, '__aenter__'): run_async(foo()) def test_with_4(self): class CM: def __enter__(self): pass def __exit__(self): pass async def foo(): async with CM(): pass with self.assertRaisesRegex(AttributeError, '__aexit__'): run_async(foo()) def test_with_5(self): # While this test doesn't make a lot of sense, # it's a regression test for an early bug with opcodes # generation class CM: async def __aenter__(self): return self async def __aexit__(self, *exc): pass async def func(): async with CM(): assert (1, ) == 1 with self.assertRaises(AssertionError): run_async(func()) def test_with_6(self): class CM: def __aenter__(self): return 123 def __aexit__(self, *e): return 456 async def foo(): async with CM(): pass with self.assertRaisesRegex( TypeError, "object int can't be used in 'await' expression"): # it's important that __aexit__ wasn't called run_async(foo()) def test_with_7(self): class CM: async def __aenter__(self): return self def __aexit__(self, *e): return 456 async def foo(): async with CM(): pass with self.assertRaisesRegex( TypeError, "object int can't be used in 'await' expression"): run_async(foo()) def test_for_1(self): aiter_calls = 0 class AsyncIter: def __init__(self): self.i = 0 async def __aiter__(self): nonlocal aiter_calls aiter_calls += 1 return self async def __anext__(self): self.i += 1 if not (self.i % 10): await AsyncYield(self.i * 10) if self.i > 100: raise StopAsyncIteration return self.i, self.i buffer = [] async def test1(): async for i1, i2 in AsyncIter(): buffer.append(i1 + i2) yielded, _ = run_async(test1()) # Make sure that __aiter__ was called only once self.assertEqual(aiter_calls, 1) self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) self.assertEqual(buffer, [i*2 for i in range(1, 101)]) buffer = [] async def test2(): nonlocal buffer async for i in AsyncIter(): buffer.append(i[0]) if i[0] == 20: break else: buffer.append('what?') buffer.append('end') yielded, _ = run_async(test2()) # Make sure that __aiter__ was called only once self.assertEqual(aiter_calls, 2) self.assertEqual(yielded, [100, 200]) self.assertEqual(buffer, [i for i in range(1, 21)] + ['end']) buffer = [] async def test3(): nonlocal buffer async for i in AsyncIter(): if i[0] > 20: continue buffer.append(i[0]) else: buffer.append('what?') buffer.append('end') yielded, _ = run_async(test3()) # Make sure that __aiter__ was called only once self.assertEqual(aiter_calls, 3) self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) self.assertEqual(buffer, [i for i in range(1, 21)] + ['what?', 'end']) def test_for_2(self): tup = (1, 2, 3) refs_before = sys.getrefcount(tup) async def foo(): async for i in tup: print('never going to happen') with self.assertRaisesRegex( TypeError, "async for' requires an object.*__aiter__.*tuple"): run_async(foo()) self.assertEqual(sys.getrefcount(tup), refs_before) def test_for_3(self): class I: def __aiter__(self): return self aiter = I() refs_before = sys.getrefcount(aiter) async def foo(): async for i in aiter: print('never going to happen') with self.assertRaisesRegex( TypeError, "async for' received an invalid object.*__aiter.*\: I"): run_async(foo()) self.assertEqual(sys.getrefcount(aiter), refs_before) def test_for_4(self): class I: async def __aiter__(self): return self def __anext__(self): return () aiter = I() refs_before = sys.getrefcount(aiter) async def foo(): async for i in aiter: print('never going to happen') with self.assertRaisesRegex( TypeError, "async for' received an invalid object.*__anext__.*tuple"): run_async(foo()) self.assertEqual(sys.getrefcount(aiter), refs_before) def test_for_5(self): class I: async def __aiter__(self): return self def __anext__(self): return 123 async def foo(): async for i in I(): print('never going to happen') with self.assertRaisesRegex( TypeError, "async for' received an invalid object.*__anext.*int"): run_async(foo()) def test_for_6(self): I = 0 class Manager: async def __aenter__(self): nonlocal I I += 10000 async def __aexit__(self, *args): nonlocal I I += 100000 class Iterable: def __init__(self): self.i = 0 async def __aiter__(self): return self async def __anext__(self): if self.i > 10: raise StopAsyncIteration self.i += 1 return self.i ############## manager = Manager() iterable = Iterable() mrefs_before = sys.getrefcount(manager) irefs_before = sys.getrefcount(iterable) async def main(): nonlocal I async with manager: async for i in iterable: I += 1 I += 1000 run_async(main()) self.assertEqual(I, 111011) self.assertEqual(sys.getrefcount(manager), mrefs_before) self.assertEqual(sys.getrefcount(iterable), irefs_before) ############## async def main(): nonlocal I async with Manager(): async for i in Iterable(): I += 1 I += 1000 async with Manager(): async for i in Iterable(): I += 1 I += 1000 run_async(main()) self.assertEqual(I, 333033) ############## async def main(): nonlocal I async with Manager(): I += 100 async for i in Iterable(): I += 1 else: I += 10000000 I += 1000 async with Manager(): I += 100 async for i in Iterable(): I += 1 else: I += 10000000 I += 1000 run_async(main()) self.assertEqual(I, 20555255) class CoroAsyncIOCompatTest(unittest.TestCase): def test_asyncio_1(self): import asyncio class MyException(Exception): pass buffer = [] class CM: async def __aenter__(self): buffer.append(1) await asyncio.sleep(0.01) buffer.append(2) return self async def __aexit__(self, exc_type, exc_val, exc_tb): await asyncio.sleep(0.01) buffer.append(exc_type.__name__) async def f(): async with CM() as c: await asyncio.sleep(0.01) raise MyException buffer.append('unreachable') loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: loop.run_until_complete(f()) except MyException: pass finally: loop.close() asyncio.set_event_loop(None) self.assertEqual(buffer, [1, 2, 'MyException']) class SysSetCoroWrapperTest(unittest.TestCase): def test_set_wrapper_1(self): async def foo(): return 'spam' wrapped = None def wrap(gen): nonlocal wrapped wrapped = gen return gen self.assertIsNone(sys.get_coroutine_wrapper()) sys.set_coroutine_wrapper(wrap) self.assertIs(sys.get_coroutine_wrapper(), wrap) try: f = foo() self.assertTrue(wrapped) self.assertEqual(run_async(f), ([], 'spam')) finally: sys.set_coroutine_wrapper(None) self.assertIsNone(sys.get_coroutine_wrapper()) wrapped = None with silence_coro_gc(): foo() self.assertFalse(wrapped) def test_set_wrapper_2(self): self.assertIsNone(sys.get_coroutine_wrapper()) with self.assertRaisesRegex(TypeError, "callable expected, got int"): sys.set_coroutine_wrapper(1) self.assertIsNone(sys.get_coroutine_wrapper()) class CAPITest(unittest.TestCase): def test_tp_await_1(self): from _testcapi import awaitType as at async def foo(): future = at(iter([1])) return (await future) self.assertEqual(foo().send(None), 1) def test_tp_await_2(self): # Test tp_await to __await__ mapping from _testcapi import awaitType as at future = at(iter([1])) self.assertEqual(next(future.__await__()), 1) def test_tp_await_3(self): from _testcapi import awaitType as at async def foo(): future = at(1) return (await future) with self.assertRaisesRegex( TypeError, "__await__.*returned non-iterator of type 'int'"): self.assertEqual(foo().send(None), 1) def test_main(): support.run_unittest(AsyncBadSyntaxTest, CoroutineTest, CoroAsyncIOCompatTest, SysSetCoroWrapperTest, CAPITest) if __name__=="__main__": test_main()