summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_asyncgen.py172
1 files changed, 172 insertions, 0 deletions
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index bc0ae8f..473bce4 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -1,12 +1,16 @@
import inspect
import types
import unittest
+import contextlib
from test.support.import_helper import import_module
from test.support import gc_collect
asyncio = import_module("asyncio")
+_no_default = object()
+
+
class AwaitException(Exception):
pass
@@ -45,6 +49,37 @@ def to_list(gen):
return run_until_complete(iterate())
+def py_anext(iterator, default=_no_default):
+ """Pure-Python implementation of anext() for testing purposes.
+
+ Closely matches the builtin anext() C implementation.
+ Can be used to compare the built-in implementation of the inner
+ coroutines machinery to C-implementation of __anext__() and send()
+ or throw() on the returned generator.
+ """
+
+ try:
+ __anext__ = type(iterator).__anext__
+ except AttributeError:
+ raise TypeError(f'{iterator!r} is not an async iterator')
+
+ if default is _no_default:
+ return __anext__(iterator)
+
+ async def anext_impl():
+ try:
+ # The C code is way more low-level than this, as it implements
+ # all methods of the iterator protocol. In this implementation
+ # we're relying on higher-level coroutine concepts, but that's
+ # exactly what we want -- crosstest pure-Python high-level
+ # implementation and low-level C anext() iterators.
+ return await __anext__(iterator)
+ except StopAsyncIteration:
+ return default
+
+ return anext_impl()
+
+
class AsyncGenSyntaxTest(unittest.TestCase):
def test_async_gen_syntax_01(self):
@@ -374,6 +409,12 @@ class AsyncGenAsyncioTest(unittest.TestCase):
asyncio.set_event_loop_policy(None)
def check_async_iterator_anext(self, ait_class):
+ with self.subTest(anext="pure-Python"):
+ self._check_async_iterator_anext(ait_class, py_anext)
+ with self.subTest(anext="builtin"):
+ self._check_async_iterator_anext(ait_class, anext)
+
+ def _check_async_iterator_anext(self, ait_class, anext):
g = ait_class()
async def consume():
results = []
@@ -406,6 +447,24 @@ class AsyncGenAsyncioTest(unittest.TestCase):
result = self.loop.run_until_complete(test_2())
self.assertEqual(result, "completed")
+ def test_send():
+ p = ait_class()
+ obj = anext(p, "completed")
+ with self.assertRaises(StopIteration):
+ with contextlib.closing(obj.__await__()) as g:
+ g.send(None)
+
+ test_send()
+
+ async def test_throw():
+ p = ait_class()
+ obj = anext(p, "completed")
+ self.assertRaises(SyntaxError, obj.throw, SyntaxError)
+ return "completed"
+
+ result = self.loop.run_until_complete(test_throw())
+ self.assertEqual(result, "completed")
+
def test_async_generator_anext(self):
async def agen():
yield 1
@@ -569,6 +628,119 @@ class AsyncGenAsyncioTest(unittest.TestCase):
result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed")
+ def test_anext_iter(self):
+ @types.coroutine
+ def _async_yield(v):
+ return (yield v)
+
+ class MyError(Exception):
+ pass
+
+ async def agenfn():
+ try:
+ await _async_yield(1)
+ except MyError:
+ await _async_yield(2)
+ return
+ yield
+
+ def test1(anext):
+ agen = agenfn()
+ with contextlib.closing(anext(agen, "default").__await__()) as g:
+ self.assertEqual(g.send(None), 1)
+ self.assertEqual(g.throw(MyError, MyError(), None), 2)
+ try:
+ g.send(None)
+ except StopIteration as e:
+ err = e
+ else:
+ self.fail('StopIteration was not raised')
+ self.assertEqual(err.value, "default")
+
+ def test2(anext):
+ agen = agenfn()
+ with contextlib.closing(anext(agen, "default").__await__()) as g:
+ self.assertEqual(g.send(None), 1)
+ self.assertEqual(g.throw(MyError, MyError(), None), 2)
+ with self.assertRaises(MyError):
+ g.throw(MyError, MyError(), None)
+
+ def test3(anext):
+ agen = agenfn()
+ with contextlib.closing(anext(agen, "default").__await__()) as g:
+ self.assertEqual(g.send(None), 1)
+ g.close()
+ with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
+ self.assertEqual(g.send(None), 1)
+
+ def test4(anext):
+ @types.coroutine
+ def _async_yield(v):
+ yield v * 10
+ return (yield (v * 10 + 1))
+
+ async def agenfn():
+ try:
+ await _async_yield(1)
+ except MyError:
+ await _async_yield(2)
+ return
+ yield
+
+ agen = agenfn()
+ with contextlib.closing(anext(agen, "default").__await__()) as g:
+ self.assertEqual(g.send(None), 10)
+ self.assertEqual(g.throw(MyError, MyError(), None), 20)
+ with self.assertRaisesRegex(MyError, 'val'):
+ g.throw(MyError, MyError('val'), None)
+
+ def test5(anext):
+ @types.coroutine
+ def _async_yield(v):
+ yield v * 10
+ return (yield (v * 10 + 1))
+
+ async def agenfn():
+ try:
+ await _async_yield(1)
+ except MyError:
+ return
+ yield 'aaa'
+
+ agen = agenfn()
+ with contextlib.closing(anext(agen, "default").__await__()) as g:
+ self.assertEqual(g.send(None), 10)
+ with self.assertRaisesRegex(StopIteration, 'default'):
+ g.throw(MyError, MyError(), None)
+
+ def test6(anext):
+ @types.coroutine
+ def _async_yield(v):
+ yield v * 10
+ return (yield (v * 10 + 1))
+
+ async def agenfn():
+ await _async_yield(1)
+ yield 'aaa'
+
+ agen = agenfn()
+ with contextlib.closing(anext(agen, "default").__await__()) as g:
+ with self.assertRaises(MyError):
+ g.throw(MyError, MyError(), None)
+
+ def run_test(test):
+ with self.subTest('pure-Python anext()'):
+ test(py_anext)
+ with self.subTest('builtin anext()'):
+ test(anext)
+
+ run_test(test1)
+ run_test(test2)
+ run_test(test3)
+ run_test(test4)
+ run_test(test5)
+ run_test(test6)
+
def test_aiter_bad_args(self):
async def gen():
yield 1