diff options
author | Dennis Sweeney <36520290+sweeneyde@users.noreply.github.com> | 2021-04-11 04:51:35 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-11 04:51:35 (GMT) |
commit | dfb45323ce8a543ca844c311e32c994ec9554c1b (patch) | |
tree | af6944feb928d3b37ad71e69df1e8da9f59a81ce /Objects/iterobject.c | |
parent | 9045919bfa820379a66ea67219f79ef6d9ecab49 (diff) | |
download | cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.zip cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.gz cpython-dfb45323ce8a543ca844c311e32c994ec9554c1b.tar.bz2 |
bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
Diffstat (limited to 'Objects/iterobject.c')
-rw-r--r-- | Objects/iterobject.c | 47 |
1 files changed, 46 insertions, 1 deletions
diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 65af18a..6961fc3 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg) static PyObject * anextawaitable_iternext(anextawaitableobject *obj) { - PyObject *result = PyIter_Next(obj->wrapped); + /* Consider the following class: + * + * class A: + * async def __anext__(self): + * ... + * a = A() + * + * Then `await anext(a)` should call + * a.__anext__().__await__().__next__() + * + * On the other hand, given + * + * async def agen(): + * yield 1 + * yield 2 + * gen = agen() + * + * Then `await anext(gen)` can just call + * gen.__anext__().__next__() + */ + assert(obj->wrapped != NULL); + PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped); + if (awaitable == NULL) { + return NULL; + } + if (Py_TYPE(awaitable)->tp_iternext == NULL) { + /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator, + * or an iterator. Of these, only coroutines lack tp_iternext. + */ + assert(PyCoro_CheckExact(awaitable)); + unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await; + PyObject *new_awaitable = getter(awaitable); + if (new_awaitable == NULL) { + Py_DECREF(awaitable); + return NULL; + } + Py_SETREF(awaitable, new_awaitable); + if (Py_TYPE(awaitable)->tp_iternext == NULL) { + PyErr_SetString(PyExc_TypeError, + "__await__ returned a non-iterable"); + Py_DECREF(awaitable); + return NULL; + } + } + PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable); + Py_DECREF(awaitable); if (result != NULL) { return result; } |