diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_asyncgen.py | 172 |
1 files changed, 172 insertions, 0 deletions
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index bc0ae8f..473bce4 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -1,12 +1,16 @@ import inspect import types import unittest +import contextlib from test.support.import_helper import import_module from test.support import gc_collect asyncio = import_module("asyncio") +_no_default = object() + + class AwaitException(Exception): pass @@ -45,6 +49,37 @@ def to_list(gen): return run_until_complete(iterate()) +def py_anext(iterator, default=_no_default): + """Pure-Python implementation of anext() for testing purposes. + + Closely matches the builtin anext() C implementation. + Can be used to compare the built-in implementation of the inner + coroutines machinery to C-implementation of __anext__() and send() + or throw() on the returned generator. + """ + + try: + __anext__ = type(iterator).__anext__ + except AttributeError: + raise TypeError(f'{iterator!r} is not an async iterator') + + if default is _no_default: + return __anext__(iterator) + + async def anext_impl(): + try: + # The C code is way more low-level than this, as it implements + # all methods of the iterator protocol. In this implementation + # we're relying on higher-level coroutine concepts, but that's + # exactly what we want -- crosstest pure-Python high-level + # implementation and low-level C anext() iterators. + return await __anext__(iterator) + except StopAsyncIteration: + return default + + return anext_impl() + + class AsyncGenSyntaxTest(unittest.TestCase): def test_async_gen_syntax_01(self): @@ -374,6 +409,12 @@ class AsyncGenAsyncioTest(unittest.TestCase): asyncio.set_event_loop_policy(None) def check_async_iterator_anext(self, ait_class): + with self.subTest(anext="pure-Python"): + self._check_async_iterator_anext(ait_class, py_anext) + with self.subTest(anext="builtin"): + self._check_async_iterator_anext(ait_class, anext) + + def _check_async_iterator_anext(self, ait_class, anext): g = ait_class() async def consume(): results = [] @@ -406,6 +447,24 @@ class AsyncGenAsyncioTest(unittest.TestCase): result = self.loop.run_until_complete(test_2()) self.assertEqual(result, "completed") + def test_send(): + p = ait_class() + obj = anext(p, "completed") + with self.assertRaises(StopIteration): + with contextlib.closing(obj.__await__()) as g: + g.send(None) + + test_send() + + async def test_throw(): + p = ait_class() + obj = anext(p, "completed") + self.assertRaises(SyntaxError, obj.throw, SyntaxError) + return "completed" + + result = self.loop.run_until_complete(test_throw()) + self.assertEqual(result, "completed") + def test_async_generator_anext(self): async def agen(): yield 1 @@ -569,6 +628,119 @@ class AsyncGenAsyncioTest(unittest.TestCase): result = self.loop.run_until_complete(do_test()) self.assertEqual(result, "completed") + def test_anext_iter(self): + @types.coroutine + def _async_yield(v): + return (yield v) + + class MyError(Exception): + pass + + async def agenfn(): + try: + await _async_yield(1) + except MyError: + await _async_yield(2) + return + yield + + def test1(anext): + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 1) + self.assertEqual(g.throw(MyError, MyError(), None), 2) + try: + g.send(None) + except StopIteration as e: + err = e + else: + self.fail('StopIteration was not raised') + self.assertEqual(err.value, "default") + + def test2(anext): + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 1) + self.assertEqual(g.throw(MyError, MyError(), None), 2) + with self.assertRaises(MyError): + g.throw(MyError, MyError(), None) + + def test3(anext): + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 1) + g.close() + with self.assertRaisesRegex(RuntimeError, 'cannot reuse'): + self.assertEqual(g.send(None), 1) + + def test4(anext): + @types.coroutine + def _async_yield(v): + yield v * 10 + return (yield (v * 10 + 1)) + + async def agenfn(): + try: + await _async_yield(1) + except MyError: + await _async_yield(2) + return + yield + + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 10) + self.assertEqual(g.throw(MyError, MyError(), None), 20) + with self.assertRaisesRegex(MyError, 'val'): + g.throw(MyError, MyError('val'), None) + + def test5(anext): + @types.coroutine + def _async_yield(v): + yield v * 10 + return (yield (v * 10 + 1)) + + async def agenfn(): + try: + await _async_yield(1) + except MyError: + return + yield 'aaa' + + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 10) + with self.assertRaisesRegex(StopIteration, 'default'): + g.throw(MyError, MyError(), None) + + def test6(anext): + @types.coroutine + def _async_yield(v): + yield v * 10 + return (yield (v * 10 + 1)) + + async def agenfn(): + await _async_yield(1) + yield 'aaa' + + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + with self.assertRaises(MyError): + g.throw(MyError, MyError(), None) + + def run_test(test): + with self.subTest('pure-Python anext()'): + test(py_anext) + with self.subTest('builtin anext()'): + test(anext) + + run_test(test1) + run_test(test2) + run_test(test3) + run_test(test4) + run_test(test5) + run_test(test6) + def test_aiter_bad_args(self): async def gen(): yield 1 |