diff options
author | Dennis Sweeney <36520290+sweeneyde@users.noreply.github.com> | 2021-04-11 04:51:35 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-11 04:51:35 (GMT) |
commit | dfb45323ce8a543ca844c311e32c994ec9554c1b (patch) | |
tree | af6944feb928d3b37ad71e69df1e8da9f59a81ce /Lib/test | |
parent | 9045919bfa820379a66ea67219f79ef6d9ecab49 (diff) | |
download | cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.zip cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.gz cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.bz2 |
bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/test_asyncgen.py | 140 |
1 files changed, 135 insertions, 5 deletions
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 99464e3..77c15c0 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -372,11 +372,8 @@ class AsyncGenAsyncioTest(unittest.TestCase): self.loop = None asyncio.set_event_loop_policy(None) - def test_async_gen_anext(self): - async def gen(): - yield 1 - yield 2 - g = gen() + def check_async_iterator_anext(self, ait_class): + g = ait_class() async def consume(): results = [] results.append(await anext(g)) @@ -388,6 +385,66 @@ class AsyncGenAsyncioTest(unittest.TestCase): with self.assertRaises(StopAsyncIteration): self.loop.run_until_complete(consume()) + async def test_2(): + g1 = ait_class() + self.assertEqual(await anext(g1), 1) + self.assertEqual(await anext(g1), 2) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + + g2 = ait_class() + self.assertEqual(await anext(g2, "default"), 1) + self.assertEqual(await anext(g2, "default"), 2) + self.assertEqual(await anext(g2, "default"), "default") + self.assertEqual(await anext(g2, "default"), "default") + + return "completed" + + result = self.loop.run_until_complete(test_2()) + self.assertEqual(result, "completed") + + def test_async_generator_anext(self): + async def agen(): + yield 1 + yield 2 + self.check_async_iterator_anext(agen) + + def test_python_async_iterator_anext(self): + class MyAsyncIter: + """Asynchronously yield 1, then 2.""" + def __init__(self): + self.yielded = 0 + def __aiter__(self): + return self + async def __anext__(self): + if self.yielded >= 2: + raise StopAsyncIteration() + else: + self.yielded += 1 + return self.yielded + self.check_async_iterator_anext(MyAsyncIter) + + def test_python_async_iterator_types_coroutine_anext(self): + import types + class MyAsyncIterWithTypesCoro: + """Asynchronously yield 1, then 2.""" + def __init__(self): + self.yielded = 0 + def __aiter__(self): + return self + @types.coroutine + def __anext__(self): + if False: + yield "this is a generator-based coroutine" + if self.yielded >= 2: + raise StopAsyncIteration() + else: + self.yielded += 1 + return self.yielded + self.check_async_iterator_anext(MyAsyncIterWithTypesCoro) + def test_async_gen_aiter(self): async def gen(): yield 1 @@ -431,12 +488,85 @@ class AsyncGenAsyncioTest(unittest.TestCase): await anext(gen(), 1, 3) async def call_with_wrong_type_args(): await anext(1, gen()) + async def call_with_kwarg(): + await anext(aiterator=gen()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_too_few_args()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_too_many_args()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_wrong_type_args()) + with self.assertRaises(TypeError): + self.loop.run_until_complete(call_with_kwarg()) + + def test_anext_bad_await(self): + async def bad_awaitable(): + class BadAwaitable: + def __await__(self): + return 42 + class MyAsyncIter: + def __aiter__(self): + return self + def __anext__(self): + return BadAwaitable() + regex = r"__await__.*iterator" + awaitable = anext(MyAsyncIter(), "default") + with self.assertRaisesRegex(TypeError, regex): + await awaitable + awaitable = anext(MyAsyncIter()) + with self.assertRaisesRegex(TypeError, regex): + await awaitable + return "completed" + result = self.loop.run_until_complete(bad_awaitable()) + self.assertEqual(result, "completed") + + async def check_anext_returning_iterator(self, aiter_class): + awaitable = anext(aiter_class(), "default") + with self.assertRaises(TypeError): + await awaitable + awaitable = anext(aiter_class()) + with self.assertRaises(TypeError): + await awaitable + return "completed" + + def test_anext_return_iterator(self): + class WithIterAnext: + def __aiter__(self): + return self + def __anext__(self): + return iter("abc") + result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext)) + self.assertEqual(result, "completed") + + def test_anext_return_generator(self): + class WithGenAnext: + def __aiter__(self): + return self + def __anext__(self): + yield + result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext)) + self.assertEqual(result, "completed") + + def test_anext_await_raises(self): + class RaisingAwaitable: + def __await__(self): + raise ZeroDivisionError() + yield + class WithRaisingAwaitableAnext: + def __aiter__(self): + return self + def __anext__(self): + return RaisingAwaitable() + async def do_test(): + awaitable = anext(WithRaisingAwaitableAnext()) + with self.assertRaises(ZeroDivisionError): + await awaitable + awaitable = anext(WithRaisingAwaitableAnext(), "default") + with self.assertRaises(ZeroDivisionError): + await awaitable + return "completed" + result = self.loop.run_until_complete(do_test()) + self.assertEqual(result, "completed") def test_aiter_bad_args(self): async def gen(): |