summaryrefslogtreecommitdiffstats
path: root/Objects/iterobject.c
diff options
context:
space:
mode:
authorDennis Sweeney <36520290+sweeneyde@users.noreply.github.com>2021-04-11 04:51:35 (GMT)
committerGitHub <noreply@github.com>2021-04-11 04:51:35 (GMT)
commitdfb45323ce8a543ca844c311e32c994ec9554c1b (patch)
treeaf6944feb928d3b37ad71e69df1e8da9f59a81ce /Objects/iterobject.c
parent9045919bfa820379a66ea67219f79ef6d9ecab49 (diff)
downloadcpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.zip
cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.gz
cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.bz2
bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
Diffstat (limited to 'Objects/iterobject.c')
-rw-r--r--Objects/iterobject.c47
1 files changed, 46 insertions, 1 deletions
diff --git a/Objects/iterobject.c b/Objects/iterobject.c
index 65af18a..6961fc3 100644
--- a/Objects/iterobject.c
+++ b/Objects/iterobject.c
@@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
static PyObject *
anextawaitable_iternext(anextawaitableobject *obj)
{
- PyObject *result = PyIter_Next(obj->wrapped);
+ /* Consider the following class:
+ *
+ * class A:
+ * async def __anext__(self):
+ * ...
+ * a = A()
+ *
+ * Then `await anext(a)` should call
+ * a.__anext__().__await__().__next__()
+ *
+ * On the other hand, given
+ *
+ * async def agen():
+ * yield 1
+ * yield 2
+ * gen = agen()
+ *
+ * Then `await anext(gen)` can just call
+ * gen.__anext__().__next__()
+ */
+ assert(obj->wrapped != NULL);
+ PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
+ if (awaitable == NULL) {
+ return NULL;
+ }
+ if (Py_TYPE(awaitable)->tp_iternext == NULL) {
+ /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
+ * or an iterator. Of these, only coroutines lack tp_iternext.
+ */
+ assert(PyCoro_CheckExact(awaitable));
+ unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
+ PyObject *new_awaitable = getter(awaitable);
+ if (new_awaitable == NULL) {
+ Py_DECREF(awaitable);
+ return NULL;
+ }
+ Py_SETREF(awaitable, new_awaitable);
+ if (Py_TYPE(awaitable)->tp_iternext == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "__await__ returned a non-iterable");
+ Py_DECREF(awaitable);
+ return NULL;
+ }
+ }
+ PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
+ Py_DECREF(awaitable);
if (result != NULL) {
return result;
}