summaryrefslogtreecommitdiffstats
path: root/Lib/test
diff options
context:
space:
mode:
authorDennis Sweeney <36520290+sweeneyde@users.noreply.github.com>2021-04-11 04:51:35 (GMT)
committerGitHub <noreply@github.com>2021-04-11 04:51:35 (GMT)
commitdfb45323ce8a543ca844c311e32c994ec9554c1b (patch)
treeaf6944feb928d3b37ad71e69df1e8da9f59a81ce /Lib/test
parent9045919bfa820379a66ea67219f79ef6d9ecab49 (diff)
downloadcpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.zip
cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.gz
cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.bz2
bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/test_asyncgen.py140
1 files changed, 135 insertions, 5 deletions
diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index 99464e3..77c15c0 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -372,11 +372,8 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.loop = None
asyncio.set_event_loop_policy(None)
- def test_async_gen_anext(self):
- async def gen():
- yield 1
- yield 2
- g = gen()
+ def check_async_iterator_anext(self, ait_class):
+ g = ait_class()
async def consume():
results = []
results.append(await anext(g))
@@ -388,6 +385,66 @@ class AsyncGenAsyncioTest(unittest.TestCase):
with self.assertRaises(StopAsyncIteration):
self.loop.run_until_complete(consume())
+ async def test_2():
+ g1 = ait_class()
+ self.assertEqual(await anext(g1), 1)
+ self.assertEqual(await anext(g1), 2)
+ with self.assertRaises(StopAsyncIteration):
+ await anext(g1)
+ with self.assertRaises(StopAsyncIteration):
+ await anext(g1)
+
+ g2 = ait_class()
+ self.assertEqual(await anext(g2, "default"), 1)
+ self.assertEqual(await anext(g2, "default"), 2)
+ self.assertEqual(await anext(g2, "default"), "default")
+ self.assertEqual(await anext(g2, "default"), "default")
+
+ return "completed"
+
+ result = self.loop.run_until_complete(test_2())
+ self.assertEqual(result, "completed")
+
+ def test_async_generator_anext(self):
+ async def agen():
+ yield 1
+ yield 2
+ self.check_async_iterator_anext(agen)
+
+ def test_python_async_iterator_anext(self):
+ class MyAsyncIter:
+ """Asynchronously yield 1, then 2."""
+ def __init__(self):
+ self.yielded = 0
+ def __aiter__(self):
+ return self
+ async def __anext__(self):
+ if self.yielded >= 2:
+ raise StopAsyncIteration()
+ else:
+ self.yielded += 1
+ return self.yielded
+ self.check_async_iterator_anext(MyAsyncIter)
+
+ def test_python_async_iterator_types_coroutine_anext(self):
+ import types
+ class MyAsyncIterWithTypesCoro:
+ """Asynchronously yield 1, then 2."""
+ def __init__(self):
+ self.yielded = 0
+ def __aiter__(self):
+ return self
+ @types.coroutine
+ def __anext__(self):
+ if False:
+ yield "this is a generator-based coroutine"
+ if self.yielded >= 2:
+ raise StopAsyncIteration()
+ else:
+ self.yielded += 1
+ return self.yielded
+ self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
+
def test_async_gen_aiter(self):
async def gen():
yield 1
@@ -431,12 +488,85 @@ class AsyncGenAsyncioTest(unittest.TestCase):
await anext(gen(), 1, 3)
async def call_with_wrong_type_args():
await anext(1, gen())
+ async def call_with_kwarg():
+ await anext(aiterator=gen())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_few_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_many_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_wrong_type_args())
+ with self.assertRaises(TypeError):
+ self.loop.run_until_complete(call_with_kwarg())
+
+ def test_anext_bad_await(self):
+ async def bad_awaitable():
+ class BadAwaitable:
+ def __await__(self):
+ return 42
+ class MyAsyncIter:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ return BadAwaitable()
+ regex = r"__await__.*iterator"
+ awaitable = anext(MyAsyncIter(), "default")
+ with self.assertRaisesRegex(TypeError, regex):
+ await awaitable
+ awaitable = anext(MyAsyncIter())
+ with self.assertRaisesRegex(TypeError, regex):
+ await awaitable
+ return "completed"
+ result = self.loop.run_until_complete(bad_awaitable())
+ self.assertEqual(result, "completed")
+
+ async def check_anext_returning_iterator(self, aiter_class):
+ awaitable = anext(aiter_class(), "default")
+ with self.assertRaises(TypeError):
+ await awaitable
+ awaitable = anext(aiter_class())
+ with self.assertRaises(TypeError):
+ await awaitable
+ return "completed"
+
+ def test_anext_return_iterator(self):
+ class WithIterAnext:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ return iter("abc")
+ result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
+ self.assertEqual(result, "completed")
+
+ def test_anext_return_generator(self):
+ class WithGenAnext:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ yield
+ result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
+ self.assertEqual(result, "completed")
+
+ def test_anext_await_raises(self):
+ class RaisingAwaitable:
+ def __await__(self):
+ raise ZeroDivisionError()
+ yield
+ class WithRaisingAwaitableAnext:
+ def __aiter__(self):
+ return self
+ def __anext__(self):
+ return RaisingAwaitable()
+ async def do_test():
+ awaitable = anext(WithRaisingAwaitableAnext())
+ with self.assertRaises(ZeroDivisionError):
+ await awaitable
+ awaitable = anext(WithRaisingAwaitableAnext(), "default")
+ with self.assertRaises(ZeroDivisionError):
+ await awaitable
+ return "completed"
+ result = self.loop.run_until_complete(do_test())
+ self.assertEqual(result, "completed")
def test_aiter_bad_args(self):
async def gen():